Add mx.finfo and use it when making causal mask (#1726)

* finfo

* fixes

* docs
This commit is contained in:
Awni Hannun
2024-12-19 14:52:41 -08:00
committed by GitHub
parent e03f0372b1
commit c3628eea49
9 changed files with 154 additions and 3 deletions

View File

@@ -1826,6 +1826,15 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertGreater(cosine(y, yq).min(), 0.99)
def test_causal_mask(self):
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16)
self.assertFalse(mx.any(mx.isnan(mask)))
self.assertTrue(mask[0, -1].item() < 0)
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16)
self.assertFalse(mx.any(mx.isnan(mask)))
self.assertTrue(mask[0, -1].item() < 0)
if __name__ == "__main__":
unittest.main()