mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values * Added 'equal_nan' to match numpy * format * Add test * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Addressed CR comments * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -855,6 +855,21 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         self.assertFalse(mx.allclose(a, b, 0.01).item()) | ||||
|         self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item()) | ||||
|  | ||||
|         c = mx.array(float("inf")) | ||||
|         self.assertTrue(mx.allclose(c, c).item()) | ||||
|  | ||||
|     def test_isclose(self): | ||||
|         a = mx.array([float("inf"), float("inf"), float("-inf")]) | ||||
|         b = mx.array([float("inf"), float("-inf"), float("-inf")]) | ||||
|  | ||||
|         self.assertListEqual(mx.isclose(a, b).tolist(), [True, False, True]) | ||||
|  | ||||
|         a = mx.array([np.nan]) | ||||
|         self.assertListEqual(mx.isclose(a, a).tolist(), [False]) | ||||
|  | ||||
|         a = mx.array([np.nan]) | ||||
|         self.assertListEqual(mx.isclose(a, a, equal_nan=True).tolist(), [True]) | ||||
|  | ||||
|     def test_all(self): | ||||
|         a = mx.array([[True, False], [True, True]]) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Rifur13
					Rifur13