From ef895f6e5bc7367bccaff3879e9f27b40e1b4ecd Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 18 Dec 2024 13:55:28 -0800 Subject: [PATCH] remove lengths --- llms/mlx_lm/models/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 538bc51c..b5fee238 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -47,7 +47,6 @@ def create_causal_mask( def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: - lengths = None window_size = None offset = 0 if cache is not None and cache[0] is not None: @@ -57,8 +56,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): window_size = c.max_size else: offset = c.offset - lengths = getattr(c, "lengths", None) - mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths) + mask = create_causal_mask(T, offset, window_size=window_size) mask = mask.astype(h.dtype) else: mask = None