From bf2da36fc640e6bfab933ac8c10d76c86fcdb288 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 12 Jan 2025 21:58:08 +0100 Subject: [PATCH] Fix Cohere2: mask shape error (long context) (#1202) * fix mask shape error (long context) * Update llms/mlx_lm/models/cohere2.py Co-authored-by: Awni Hannun * revert layer_idx * black formatting * Update cohere2.py * format --------- Co-authored-by: Awni Hannun Co-authored-by: Awni Hannun --- llms/mlx_lm/models/cohere2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index ec0e9276..19bfa6b6 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -156,12 +156,13 @@ class CohereModel(nn.Module): ): h = self.embed_tokens(inputs) - if mask is None: - mask = create_attention_mask(h, cache) - if cache is None: cache = [None] * len(self.layers) + if mask is None: + j = self.args.sliding_window_pattern + mask = create_attention_mask(h, cache[j - 1 : j]) + for layer, c in zip(self.layers, cache): h = layer(h, mask, c)