diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 538bc51c..b5fee238 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -47,7 +47,6 @@ def create_causal_mask( def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: - lengths = None window_size = None offset = 0 if cache is not None and cache[0] is not None: @@ -57,8 +56,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): window_size = c.max_size else: offset = c.offset - lengths = getattr(c, "lengths", None) - mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths) + mask = create_causal_mask(T, offset, window_size=window_size) mask = mask.astype(h.dtype) else: mask = None