fix gpt bigcode (#1204)

This commit is contained in:
Awni Hannun
2025-01-13 10:22:32 -08:00
committed by GitHub
parent 0228c46434
commit c117af83b8

View File

@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if mask is not None and hidden_states.shape[1] > 1:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
position_ids = mx.array(np.arange(L))
else:
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
hidden_states += self.wpe(position_ids)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)