[Fix] mx.allclose bug with infinite values (#539)

* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Rifur13
2024-01-25 23:47:06 -05:00
committed by GitHub
parent 87b7fa9ba2
commit 2463496471
5 changed files with 143 additions and 11 deletions

View File

@@ -855,6 +855,21 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertFalse(mx.allclose(a, b, 0.01).item())
self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item())
c = mx.array(float("inf"))
self.assertTrue(mx.allclose(c, c).item())
def test_isclose(self):
a = mx.array([float("inf"), float("inf"), float("-inf")])
b = mx.array([float("inf"), float("-inf"), float("-inf")])
self.assertListEqual(mx.isclose(a, b).tolist(), [True, False, True])
a = mx.array([np.nan])
self.assertListEqual(mx.isclose(a, a).tolist(), [False])
a = mx.array([np.nan])
self.assertListEqual(mx.isclose(a, a, equal_nan=True).tolist(), [True])
def test_all(self):
a = mx.array([[True, False], [True, True]])