use a bool mask for attention (#1319)

Awni Hannun 2025-03-04 12:47:32 -08:00 committed by GitHub
parent 1bc3476a46
commit 65aa2ec849
2 changed files with 5 additions and 6 deletions


@@ -152,7 +152,7 @@ def setup_arg_parser():
     parser.add_argument(
         "--num-draft-tokens",
         type=int,
         help="Number of tokens to draft when using speculative decoding.",
-        default=2,
+        default=3,
     )
     return parser


@@ -33,13 +33,13 @@ def create_causal_mask(
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
     rinds = rinds[None]
-    mask = linds < rinds
+    mask = linds >= rinds
     if window_size is not None:
-        mask = mask | (linds > rinds + window_size)
+        mask = mask & (linds <= rinds + window_size)
     if lengths is not None:
         lengths = lengths[:, None, None, None]
-        mask = mask | (rinds >= lengths)
-    return mask * -1e9
+        mask = mask & (rinds < lengths)
+    return mask
 
 
 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
@@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
             else:
                 offset = c.offset
         mask = create_causal_mask(T, offset, window_size=window_size)
-        mask = mask.astype(h.dtype)
     else:
         mask = None
     return mask
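To make the new convention concrete, here is a minimal standalone sketch of the patched helper (a copy for illustration, not an import; it assumes only that MLX is installed). With a boolean mask, True marks key positions a query may attend to, so constraints compose with & instead of accumulating -1e9 penalties:

import mlx.core as mx

def create_causal_mask(N, offset=0, window_size=None, lengths=None):
    # Standalone copy of the patched helper: True = "may attend", False = "masked out".
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    linds = linds[:, None]
    rinds = rinds[None]
    mask = linds >= rinds  # causal: a query sees keys at or before itself
    if window_size is not None:
        mask = mask & (linds <= rinds + window_size)  # sliding-window cutoff
    if lengths is not None:
        lengths = lengths[:, None, None, None]
        mask = mask & (rinds < lengths)  # drop padded key positions
    return mask

# Expect a lower-triangular True pattern for a 4-token prompt with no offset.
print(create_causal_mask(4))

Dropping the mask.astype(h.dtype) cast in create_attention_mask suggests the boolean array is consumed directly downstream; recent MLX releases accept boolean masks in mx.fast.scaled_dot_product_attention, so no float conversion is needed.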