From c117af83b8cbec15523bd0d69e7a57f01237ca89 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 13 Jan 2025 10:22:32 -0800
Subject: [PATCH] fix gpt bigcode (#1204)

---
 llms/mlx_lm/models/gpt_bigcode.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 8415c59e..1d9794b6 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
         hidden_states = self.wte(inputs)
 
         mask = None
-        if hidden_states.shape[1] > 1:
-
-            position_ids = mx.array(np.arange(L))
-            hidden_states += self.wpe(position_ids)
-
-            if mask is None:
-                mask = create_attention_mask(hidden_states, cache)
+        if mask is None and hidden_states.shape[1] > 1:
+            mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
+            position_ids = mx.array(np.arange(L))
+        else:
+            position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
+
+        hidden_states += self.wpe(position_ids)
 
         for layer, c in zip(self.h, cache):
             hidden_states = layer(hidden_states, mask, cache=c)
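
The change above does two things: it hoists the attention-mask creation out of the position-embedding branch, and it derives the position ids from the cache offset. Before the patch, both the mask and wpe were guarded by hidden_states.shape[1] > 1, so single-token decode steps never received a position embedding; after it, the embedding continues from the number of tokens already in the cache instead of restarting at 0. Below is a minimal sketch of that position-id logic in plain NumPy, not the mlx_lm implementation; SimpleKVCache and position_ids_for_step are hypothetical names, and the cache is assumed to expose an offset counter as the diff suggests.

import numpy as np


class SimpleKVCache:
    """Hypothetical stand-in for an mlx_lm KV cache: only tracks how many
    tokens have already been processed (the `offset` the patch reads)."""

    def __init__(self):
        self.offset = 0

    def advance(self, n_tokens):
        self.offset += n_tokens


def position_ids_for_step(n_new_tokens, cache=None):
    # Without a cache, positions start at 0 (prompt processing).
    if cache is None:
        return np.arange(n_new_tokens)
    # With a cache, positions continue from the tokens already seen,
    # mirroring np.arange(cache[0].offset, cache[0].offset + L) in the patch.
    return np.arange(cache.offset, cache.offset + n_new_tokens)


# A 5-token prompt followed by two single-token generation steps.
cache = SimpleKVCache()
print(position_ids_for_step(5, cache))  # [0 1 2 3 4]
cache.advance(5)
print(position_ids_for_step(1, cache))  # [5]
cache.advance(1)
print(position_ids_for_step(1, cache))  # [6]

Offsetting by cache[0].offset is what keeps the learned position embedding aligned with each token's absolute position during incremental generation, which is presumably why the wpe addition was moved out of the multi-token branch entirely.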