mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
use a bool mask for attention (#1319)
This commit is contained in:
parent
1bc3476a46
commit
65aa2ec849
@ -152,7 +152,7 @@ def setup_arg_parser():
|
||||
"--num-draft-tokens",
|
||||
type=int,
|
||||
help="Number of tokens to draft when using speculative decoding.",
|
||||
default=2,
|
||||
default=3,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
@ -33,13 +33,13 @@ def create_causal_mask(
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
linds = linds[:, None]
|
||||
rinds = rinds[None]
|
||||
mask = linds < rinds
|
||||
mask = linds >= rinds
|
||||
if window_size is not None:
|
||||
mask = mask | (linds > rinds + window_size)
|
||||
mask = mask & (linds <= rinds + window_size)
|
||||
if lengths is not None:
|
||||
lengths = lengths[:, None, None, None]
|
||||
mask = mask | (rinds >= lengths)
|
||||
return mask * -1e9
|
||||
mask = mask & (rinds < lengths)
|
||||
return mask
|
||||
|
||||
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
else:
|
||||
offset = c.offset
|
||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||
mask = mask.astype(h.dtype)
|
||||
else:
|
||||
mask = None
|
||||
return mask
|
||||
|
Loading…
Reference in New Issue
Block a user