diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 8415c59e..1d9794b6 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
         hidden_states = self.wte(inputs)
 
         mask = None
-        if hidden_states.shape[1] > 1:
-
-            position_ids = mx.array(np.arange(L))
-            hidden_states += self.wpe(position_ids)
-
-            if mask is None:
-                mask = create_attention_mask(hidden_states, cache)
+        if hidden_states.shape[1] > 1:
+            mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
+            position_ids = mx.array(np.arange(L))
+        else:
+            position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
+
+        hidden_states += self.wpe(position_ids)
 
         for layer, c in zip(self.h, cache):
             hidden_states = layer(hidden_states, mask, cache=c)
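
The core of the change: the old code always built position ids as `arange(L)`, so during incremental decoding with a KV cache every new token was embedded at position 0, and position embeddings were skipped entirely for single-token steps (`shape[1] > 1` was false). The new code offsets the positions by the number of tokens already in the cache. A minimal sketch of that behavior, assuming (as in mlx_lm's KV cache) a cache object exposing an integer `offset` attribute; `FakeCache` below is a hypothetical stand-in, not the real cache class:

```python
import mlx.core as mx
import numpy as np


class FakeCache:
    """Hypothetical stand-in for mlx_lm's KV cache: only the integer
    `offset` attribute (tokens already processed) matters here."""

    def __init__(self, offset: int):
        self.offset = offset


def position_ids(L: int, cache=None) -> mx.array:
    if cache is None:
        # Prompt (prefill) pass: no tokens cached yet, positions start at 0.
        return mx.array(np.arange(L))
    # Incremental decode: the L new tokens continue after the
    # cache.offset tokens already seen, so positions must be shifted.
    return mx.array(np.arange(cache.offset, cache.offset + L))


print(position_ids(4))                # array([0, 1, 2, 3], ...)  -- prefill
print(position_ids(1, FakeCache(4)))  # array([4], ...)           -- next decoded token
```

Moving `hidden_states += self.wpe(position_ids)` out of the `shape[1] > 1` branch also ensures the learned position embedding is applied on every step, not just when processing a multi-token prompt.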