diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 33a077346..84291dac1 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -52,6 +52,7 @@ Operations identity inner isnan + isinf less less_equal linspace diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2112a2a4d..c185568ff 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1084,6 +1084,10 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) { return not_equal(a, a, s); } +array isinf(const array& a, StreamOrDevice s /* = {} */) { + return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); +} + array where( const array& condition, const array& x, diff --git a/mlx/ops.h b/mlx/ops.h index f0823ed5f..c8efa86fa 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -378,6 +378,8 @@ array_equal(const array& a, const array& b, StreamOrDevice s = {}) { array isnan(const array& a, StreamOrDevice s = {}); +array isinf(const array& a, StreamOrDevice s = {}); + /** Select from x or y depending on condition. */ array where( const array& condition, diff --git a/python/src/array.cpp b/python/src/array.cpp index f7dfea90a..8f6e1ac4f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1134,12 +1134,6 @@ void init_array(py::module_& m) { py::kw_only(), "stream"_a = none, "See :func:`any`.") - .def( - "isnan", - &mlx::core::isnan, - py::kw_only(), - "stream"_a = none, - "See :func:`isnan`.") .def( "moveaxis", &moveaxis, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c9c29532c..e8d69f03d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1836,7 +1836,25 @@ void init_ops(py::module_& m) { a (array): Input array. Returns: - array: The array with boolean values indicating which elements are NaN. + array: The boolean array indicating which elements are NaN. + )pbdoc"); + m.def( + "isinf", + &mlx::core::isinf, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + isinf(a: array, stream: Union[None, Stream, Device] = None) -> array + + Return a boolean array indicating which elements are +/- inifnity. + + Args: + a (array): Input array. + + Returns: + array: The boolean array indicating which elements are +/- infinity. )pbdoc"); m.def( "moveaxis", diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 84135427f..835621104 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -336,6 +336,21 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.isnan(0 * mx.array(float("inf"))).tolist(), True) + def test_isinf(self): + x = mx.array([0.0, float("inf")]) + self.assertEqual(mx.isinf(x).tolist(), [False, True]) + + x = mx.array([0.0, float("inf")]).astype(mx.float16) + self.assertEqual(mx.isinf(x).tolist(), [False, True]) + + x = mx.array([0.0, float("inf")]).astype(mx.bfloat16) + self.assertEqual(mx.isinf(x).tolist(), [False, True]) + + x = mx.array([0.0, float("inf")]).astype(mx.complex64) + self.assertEqual(mx.isinf(x).tolist(), [False, True]) + + self.assertEqual(mx.isinf(0 * mx.array(float("inf"))).tolist(), False) + def test_tri(self): for shape in [[4], [4, 4], [2, 10]]: for diag in [-1, 0, 1, -2]: diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 84c35ff2f..b2be73517 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -505,6 +505,32 @@ TEST_CASE("test is nan") { CHECK(isnan(d).item()); } +TEST_CASE("test is inf") { + array x(1.0f); + CHECK_FALSE(isinf(x).item()); + + array y(std::numeric_limits::infinity()); + CHECK(isinf(y).item()); + + array z = identity(7); + CHECK_FALSE(any(isinf(z)).item()); + + array w = array({1.0f, std::numeric_limits::infinity(), 2.0f}); + CHECK(array_equal({false, true, false}, isinf(w)).item()); + + array a(1.0f, bfloat16); + CHECK_FALSE(isinf(a).item()); + + array b(1.0f, float16); + CHECK_FALSE(isinf(b).item()); + + array c(std::numeric_limits::infinity(), bfloat16); + CHECK(isinf(c).item()); + + array d(std::numeric_limits::infinity(), float16); + CHECK(isinf(d).item()); +} + TEST_CASE("test all close") { array x(1.0f); array y(1.0f);