Add mx.finfo and use it when making causal mask (#1726)

* finfo

* fixes

* docs
This commit is contained in:
Awni Hannun
2024-12-19 14:52:41 -08:00
committed by GitHub
parent e03f0372b1
commit c3628eea49
9 changed files with 154 additions and 3 deletions

View File

@@ -102,9 +102,7 @@ class MultiHeadAttention(Module):
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
    """Build an ``(N, N)`` additive causal attention mask.

    Entry ``[i, j]`` is zero when ``j <= i`` and ``mx.finfo(dtype).min``
    (the most negative finite value of ``dtype``) when ``j > i``, so that
    adding the mask to attention scores suppresses positions after the
    current one.

    Args:
        N: Side length of the square mask.
        dtype: Dtype of the returned mask. Defaults to ``mx.float32``.

    Returns:
        An array of shape ``(N, N)`` with dtype ``dtype``.
    """
    indices = mx.arange(N)
    # True strictly above the diagonal: column index greater than row index.
    mask = indices[:, None] < indices[None]
    # Scale by the dtype's own lowest finite value instead of a fixed -1e9,
    # which does not fit in narrower dtypes such as float16.
    mask = mask.astype(dtype) * mx.finfo(dtype).min
    return mask