From 9d7f38abb8dc7d464bea2ad382f15e3bc4030210 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sun, 12 Jan 2025 15:25:02 +0100
Subject: [PATCH] revert layer_idx

---
 llms/mlx_lm/models/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 1352ddf3..ad7a4a65 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,13 +42,13 @@ def create_causal_mask(
     return mask * -1e9
 
 
-def create_attention_mask(h: mx.array, cache: Optional[Any] = None, layer_idx: int = 0):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
-        if cache is not None and cache[layer_idx] is not None:
-            c = cache[layer_idx]
+        if cache is not None and cache[0] is not None:
+            c = cache[0]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size