diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 3dcd3660d..33a077346 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -51,6 +51,7 @@ Operations greater_equal identity inner + isnan less less_equal linspace diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 60a02bb86..5f7f60b99 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1050,6 +1050,10 @@ array array_equal( } } +array isnan(const array& a, StreamOrDevice s /* = {} */) { + return not_equal(a, a, s); +} + array where( const array& condition, const array& x, diff --git a/mlx/ops.h b/mlx/ops.h index 72534bf44..f095a06dd 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -374,6 +374,8 @@ array_equal(const array& a, const array& b, StreamOrDevice s = {}) { return array_equal(a, b, false, s); } +array isnan(const array& a, StreamOrDevice s = {}); + /** Select from x or y depending on condition. */ array where( const array& condition, diff --git a/python/src/array.cpp b/python/src/array.cpp index e9575af95..26f6e68ac 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1132,6 +1132,12 @@ void init_array(py::module_& m) { py::kw_only(), "stream"_a = none, "See :func:`any`.") + .def( + "isnan", + &mlx::core::isnan, + py::kw_only(), + "stream"_a = none, + "See :func:`isnan`.") .def( "moveaxis", &moveaxis, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2d60db6aa..9a47f0ed8 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1820,6 +1820,24 @@ void init_ops(py::module_& m) { Returns: array: The ceil of ``a``. )pbdoc"); + m.def( + "isnan", + &mlx::core::isnan, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + isnan(a: array, stream: Union[None, Stream, Device] = None) -> array + + Return a boolean array indicating which elements are NaN. + + Args: + a (array): Input array. + + Returns: + array: The array with boolean values indicating which elements are NaN. + )pbdoc"); m.def( "moveaxis", &moveaxis, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 433188b9a..3b889359c 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -321,6 +321,21 @@ class TestOps(mlx_tests.MLXTestCase): self.assertFalse(mx.array_equal(x, y)) self.assertTrue(mx.array_equal(x, y, equal_nan=True)) + def test_isnan(self): + x = mx.array([0.0, float("nan")]) + self.assertEqual(mx.isnan(x).tolist(), [False, True]) + + x = mx.array([0.0, float("nan")]).astype(mx.float16) + self.assertEqual(mx.isnan(x).tolist(), [False, True]) + + x = mx.array([0.0, float("nan")]).astype(mx.bfloat16) + self.assertEqual(mx.isnan(x).tolist(), [False, True]) + + x = mx.array([0.0, float("nan")]).astype(mx.complex64) + self.assertEqual(mx.isnan(x).tolist(), [False, True]) + + self.assertEqual(mx.isnan(0 * mx.array(float("inf"))).tolist(), True) + def test_tri(self): for shape in [[4], [4, 4], [2, 10]]: for diag in [-1, 0, 1, -2]: diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 70b4c82e4..d8d76b94f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -479,6 +479,32 @@ TEST_CASE("test comparison ops") { } } +TEST_CASE("test is nan") { + array x(1.0f); + CHECK_FALSE(isnan(x).item()); + + array y(NAN); + CHECK(isnan(y).item()); + + array z = identity(7); + CHECK_FALSE(all(isnan(z)).item()); + + array w = array({1.0f, NAN, 2.0f}); + CHECK_FALSE(all(isnan(w)).item()); + + array a(1.0f, bfloat16); + CHECK_FALSE(isnan(a).item()); + + array b(1.0f, float16); + CHECK_FALSE(isnan(b).item()); + + array c(NAN, bfloat16); + CHECK(isnan(c).item()); + + array d(NAN, float16); + CHECK(isnan(d).item()); +} + TEST_CASE("test all close") { array x(1.0f); array y(1.0f);