fix gpt bigcode (#1204)

This commit is contained in:
Awni Hannun 2025-01-13 10:22:32 -08:00 committed by GitHub
parent 0228c46434
commit c117af83b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
         hidden_states = self.wte(inputs)
         mask = None
-        if hidden_states.shape[1] > 1:
-            position_ids = mx.array(np.arange(L))
-            hidden_states += self.wpe(position_ids)
-            if mask is None:
-                mask = create_attention_mask(hidden_states, cache)
+        if mask is not None and hidden_states.shape[1] > 1:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)
+            position_ids = mx.array(np.arange(L))
+        else:
+            position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
+        hidden_states += self.wpe(position_ids)
         for layer, c in zip(self.h, cache):
             hidden_states = layer(hidden_states, mask, cache=c)