mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Add mx.finfo
and use it when making causal mask (#1726)
* finfo * fixes * docs
This commit is contained in:
@@ -1826,6 +1826,15 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertGreater(cosine(y, yq).min(), 0.99)
|
||||
|
||||
def test_causal_mask(self):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16)
|
||||
self.assertFalse(mx.any(mx.isnan(mask)))
|
||||
self.assertTrue(mask[0, -1].item() < 0)
|
||||
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16)
|
||||
self.assertFalse(mx.any(mx.isnan(mask)))
|
||||
self.assertTrue(mask[0, -1].item() < 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user