mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add mx.finfo and use it when making causal mask (#1726)
				
					
				
			* finfo * fixes * docs
This commit is contained in:
		| @@ -97,6 +97,18 @@ class TestDtypes(mlx_tests.MLXTestCase): | ||||
|                 self.assertListEqual(list(z.shape), list(x.shape)) | ||||
|                 self.assertListEqual(list(z.shape), list(y.shape)) | ||||
|  | ||||
|     def test_finfo(self): | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.finfo(mx.int32) | ||||
|  | ||||
|         self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min) | ||||
|         self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max) | ||||
|         self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32) | ||||
|  | ||||
|         self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min) | ||||
|         self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) | ||||
|         self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) | ||||
|  | ||||
|  | ||||
| class TestEquality(mlx_tests.MLXTestCase): | ||||
|     def test_array_eq_array(self): | ||||
|   | ||||
| @@ -1826,6 +1826,15 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertGreater(cosine(y, yq).min(), 0.99) | ||||
|  | ||||
|     def test_causal_mask(self): | ||||
|         mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16) | ||||
|         self.assertFalse(mx.any(mx.isnan(mask))) | ||||
|         self.assertTrue(mask[0, -1].item() < 0) | ||||
|  | ||||
|         mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16) | ||||
|         self.assertFalse(mx.any(mx.isnan(mask))) | ||||
|         self.assertTrue(mask[0, -1].item() < 0) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun