From 1107364c3ac8c81fe5bf8912804c05a5fbca51e6 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sat, 11 Jan 2025 21:56:00 +0100
Subject: [PATCH] fix mask shape error (long context)

---
 llms/mlx_lm/models/base.py    | 6 +++---
 llms/mlx_lm/models/cohere2.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index ad7a4a65..1352ddf3 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,13 +42,13 @@ def create_causal_mask(
     return mask * -1e9
 
 
-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, layer_idx: int = 0):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
-        if cache is not None and cache[0] is not None:
-            c = cache[0]
+        if cache is not None and cache[layer_idx] is not None:
+            c = cache[layer_idx]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size
diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
index ec0e9276..b1c9f2da 100644
--- a/llms/mlx_lm/models/cohere2.py
+++ b/llms/mlx_lm/models/cohere2.py
@@ -157,7 +157,7 @@ class CohereModel(nn.Module):
         h = self.embed_tokens(inputs)
 
         if mask is None:
-            mask = create_attention_mask(h, cache)
+            mask = create_attention_mask(h, cache, layer_idx=self.args.sliding_window_pattern - 1)
 
         if cache is None:
             cache = [None] * len(self.layers)