remove lengths

This commit is contained in:
Alex Barron 2024-12-18 13:55:28 -08:00
parent cd9dcf0383
commit ef895f6e5b

View File

@ -47,7 +47,6 @@ def create_causal_mask(
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
lengths = None
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
@ -57,8 +56,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
window_size = c.max_size
else:
offset = c.offset
lengths = getattr(c, "lengths", None)
mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None