mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 07:30:06 +08:00
remove lengths
This commit is contained in:
parent
cd9dcf0383
commit
ef895f6e5b
@ -47,7 +47,6 @@ def create_causal_mask(
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
lengths = None
|
||||
window_size = None
|
||||
offset = 0
|
||||
if cache is not None and cache[0] is not None:
|
||||
@ -57,8 +56,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
window_size = c.max_size
|
||||
else:
|
||||
offset = c.offset
|
||||
lengths = getattr(c, "lengths", None)
|
||||
mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
|
||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||
mask = mask.astype(h.dtype)
|
||||
else:
|
||||
mask = None
|
||||
|
Loading…
Reference in New Issue
Block a user