mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Propagate nans in binary ops (#579)
* propagate nans in binary ops * handle empty matmul * cpu minimum/maximum propagate nan * benchmark maximum * add min as well * throw on negative indices with full * verbose on linux * fix matmul for zero K
This commit is contained in:
		| @@ -386,6 +386,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         expected = [0, -7, 3] | ||||
|         self.assertListEqual(mx.minimum(x, y).tolist(), expected) | ||||
|  | ||||
|         a = mx.array([float("nan")]) | ||||
|         b = mx.array([0.0]) | ||||
|         self.assertTrue(math.isnan(mx.minimum(a, b).item())) | ||||
|         self.assertTrue(math.isnan(mx.minimum(b, a).item())) | ||||
|  | ||||
|     def test_maximum(self): | ||||
|         x = mx.array([0.0, -5, 10.0]) | ||||
|         y = mx.array([1.0, -7.0, 3.0]) | ||||
| @@ -393,6 +398,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         expected = [1, -5, 10] | ||||
|         self.assertListEqual(mx.maximum(x, y).tolist(), expected) | ||||
|  | ||||
|         a = mx.array([float("nan")]) | ||||
|         b = mx.array([0.0]) | ||||
|         self.assertTrue(math.isnan(mx.maximum(a, b).item())) | ||||
|         self.assertTrue(math.isnan(mx.maximum(b, a).item())) | ||||
|  | ||||
|     def test_floor(self): | ||||
|         x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]) | ||||
|         expected = [-23, 19, -27, 9, 0, -np.inf, np.inf] | ||||
| @@ -760,6 +770,10 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         a = mx.array([float("nan")]) | ||||
|         b = mx.array([0.0]) | ||||
|         self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) | ||||
|  | ||||
|     def test_log(self): | ||||
|         a = mx.array([1, 0.5, 10, 100]) | ||||
|         result = mx.log(a) | ||||
| @@ -1761,6 +1775,16 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         ) | ||||
|         self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile) | ||||
|  | ||||
|     def test_empty_matmuls(self): | ||||
|         a = mx.array([]) | ||||
|         b = mx.array([]) | ||||
|         self.assertEqual(mx.inner(a, b).item(), 0.0) | ||||
|  | ||||
|         a = mx.zeros((10, 0)) | ||||
|         b = mx.zeros((0, 10)) | ||||
|         out = a @ b | ||||
|         self.assertTrue(mx.array_equal(out, mx.zeros((10, 10)))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun