mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix gpt bigcode (#1204)
This commit is contained in:
parent
0228c46434
commit
c117af83b8
@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
hidden_states = self.wte(inputs)
|
hidden_states = self.wte(inputs)
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if hidden_states.shape[1] > 1:
|
if mask is not None and hidden_states.shape[1] > 1:
|
||||||
|
|
||||||
position_ids = mx.array(np.arange(L))
|
|
||||||
hidden_states += self.wpe(position_ids)
|
|
||||||
|
|
||||||
if mask is None:
|
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.h)
|
cache = [None] * len(self.h)
|
||||||
|
position_ids = mx.array(np.arange(L))
|
||||||
|
else:
|
||||||
|
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
|
||||||
|
|
||||||
|
hidden_states += self.wpe(position_ids)
|
||||||
|
|
||||||
for layer, c in zip(self.h, cache):
|
for layer, c in zip(self.h, cache):
|
||||||
hidden_states = layer(hidden_states, mask, cache=c)
|
hidden_states = layer(hidden_states, mask, cache=c)
|
||||||
|
Loading…
Reference in New Issue
Block a user