mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	consistently handle all -inf in softmax (#1470)
This commit is contained in:
		| @@ -1567,6 +1567,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             out = mx.softmax(a, axis=-1, precise=True) | ||||
|             self.assertTrue(mx.allclose(out_expect, out)) | ||||
|  | ||||
|         # All Infs give NaNs | ||||
|         for n in [127, 128, 129]: | ||||
|             x = mx.full((n,), vals=-float("inf")) | ||||
|             self.assertTrue(mx.all(mx.isnan(mx.softmax(x)))) | ||||
|  | ||||
|     def test_concatenate(self): | ||||
|         a_npy = np.random.randn(32, 32, 32) | ||||
|         b_npy = np.random.randn(32, 32, 32) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun