Adds isinf (#445)

* adds isinf

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>

* use stream + nits

* typo

---------

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Matthew Ernst
2024-01-15 19:50:44 -08:00
committed by GitHub
parent 6022d4129e
commit 92a2fdd577
7 changed files with 67 additions and 7 deletions

View File

@@ -336,6 +336,21 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.isnan(0 * mx.array(float("inf"))).tolist(), True)
def test_isinf(self):
x = mx.array([0.0, float("inf")])
self.assertEqual(mx.isinf(x).tolist(), [False, True])
x = mx.array([0.0, float("inf")]).astype(mx.float16)
self.assertEqual(mx.isinf(x).tolist(), [False, True])
x = mx.array([0.0, float("inf")]).astype(mx.bfloat16)
self.assertEqual(mx.isinf(x).tolist(), [False, True])
x = mx.array([0.0, float("inf")]).astype(mx.complex64)
self.assertEqual(mx.isinf(x).tolist(), [False, True])
self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False)
def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]: