From 0e908bddff8cd0a4842b286d4c1a62bba8bde0bf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 12 Jan 2025 10:58:18 -0800 Subject: [PATCH] Update cohere2.py --- llms/mlx_lm/models/cohere2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index 784bcfcf..d489d2b9 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -156,13 +156,13 @@ class CohereModel(nn.Module): ): h = self.embed_tokens(inputs) + if cache is None: + cache = [None] * len(self.layers) + if mask is None: j = self.args.sliding_window_pattern mask = create_attention_mask(h, cache[j - 1 : j]) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): h = layer(h, mask, c)