mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:04 +08:00
remove lengths
This commit is contained in:
parent
cd9dcf0383
commit
ef895f6e5b
@ -47,7 +47,6 @@ def create_causal_mask(
|
|||||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||||
T = h.shape[1]
|
T = h.shape[1]
|
||||||
if T > 1:
|
if T > 1:
|
||||||
lengths = None
|
|
||||||
window_size = None
|
window_size = None
|
||||||
offset = 0
|
offset = 0
|
||||||
if cache is not None and cache[0] is not None:
|
if cache is not None and cache[0] is not None:
|
||||||
@ -57,8 +56,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
|||||||
window_size = c.max_size
|
window_size = c.max_size
|
||||||
else:
|
else:
|
||||||
offset = c.offset
|
offset = c.offset
|
||||||
lengths = getattr(c, "lengths", None)
|
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||||
mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
|
|
||||||
mask = mask.astype(h.dtype)
|
mask = mask.astype(h.dtype)
|
||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
|
Loading…
Reference in New Issue
Block a user