diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 1352ddf3..ad7a4a65 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -42,13 +42,13 @@ def create_causal_mask( return mask * -1e9 -def create_attention_mask(h: mx.array, cache: Optional[Any] = None, layer_idx: int = 0): +def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: window_size = None offset = 0 - if cache is not None and cache[layer_idx] is not None: - c = cache[layer_idx] + if cache is not None and cache[0] is not None: + c = cache[0] if hasattr(c, "max_size"): offset = min(c.max_size, c.offset) window_size = c.max_size