mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Add mx.finfo
and use it when making causal mask (#1726)
* finfo * fixes * docs
This commit is contained in:
@@ -102,9 +102,7 @@ class MultiHeadAttention(Module):
|
||||
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
||||
indices = mx.arange(N)
|
||||
mask = indices[:, None] < indices[None]
|
||||
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
|
||||
# TODO: Should replace this with finfo(dtype).min
|
||||
mask = mask.astype(dtype) * -1e9
|
||||
mask = mask.astype(dtype) * mx.finfo(dtype).min
|
||||
return mask
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user