diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 84291dac1..649724a34 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -52,6 +52,8 @@ Operations identity inner isnan + isposinf + isneginf isinf less less_equal diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 61983197a..87a7d5e96 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1088,6 +1088,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) { return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); } +array isposinf(const array& a, StreamOrDevice s) { + return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); +} + +array isneginf(const array& a, StreamOrDevice s) { + return equal(a, array(-std::numeric_limits::infinity(), a.dtype()), s); +} + array where( const array& condition, const array& x, diff --git a/mlx/ops.h b/mlx/ops.h index c8efa86fa..16f85e147 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -380,6 +380,10 @@ array isnan(const array& a, StreamOrDevice s = {}); array isinf(const array& a, StreamOrDevice s = {}); +array isposinf(const array& a, StreamOrDevice s = {}); + +array isneginf(const array& a, StreamOrDevice s = {}); + /** Select from x or y depending on condition. */ array where( const array& condition, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index e8d69f03d..9db273b40 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1856,6 +1856,44 @@ void init_ops(py::module_& m) { Returns: array: The boolean array indicating which elements are +/- infinity. )pbdoc"); + m.def( + "isposinf", + &isposinf, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array + + Return a boolean array indicating which elements are positive infinity. + + Args: + a (array): Input array. + stream (Union[None, Stream, Device]): Optional stream or device. + + Returns: + array: The boolean array indicating which elements are positive infinity. + )pbdoc"); + m.def( + "isneginf", + &isneginf, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array + + Return a boolean array indicating which elements are negative infinity. + + Args: + a (array): Input array. + stream (Union[None, Stream, Device]): Optional stream or device. + + Returns: + array: The boolean array indicating which elements are negative infinity. + )pbdoc"); m.def( "moveaxis", &moveaxis, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 835621104..24f82d40b 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -401,6 +401,36 @@ class TestOps(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.ceil(mx.array([22 + 3j, 19 + 98j])) + def test_isposinf(self): + x = mx.array([0.0, float("-inf")]) + self.assertEqual(mx.isposinf(x).tolist(), [False, False]) + + x = mx.array([0.0, float("-inf")]).astype(mx.float16) + self.assertEqual(mx.isposinf(x).tolist(), [False, False]) + + x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16) + self.assertEqual(mx.isposinf(x).tolist(), [False, False]) + + x = mx.array([0.0, float("-inf")]).astype(mx.complex64) + self.assertEqual(mx.isposinf(x).tolist(), [False, False]) + + self.assertEqual(mx.isposinf(0 * mx.array(float("inf"))).tolist(), False) + + def test_isneginf(self): + x = mx.array([0.0, float("-inf")]) + self.assertEqual(mx.isneginf(x).tolist(), [False, True]) + + x = mx.array([0.0, float("-inf")]).astype(mx.float16) + self.assertEqual(mx.isneginf(x).tolist(), [False, True]) + + x = mx.array([0.0, float("-inf")]).astype(mx.bfloat16) + self.assertEqual(mx.isneginf(x).tolist(), [False, True]) + + x = mx.array([0.0, float("-inf")]).astype(mx.complex64) + self.assertEqual(mx.isneginf(x).tolist(), [False, True]) + + self.assertEqual(mx.isneginf(0 * mx.array(float("inf"))).tolist(), False) + def test_round(self): # float x = mx.array( diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 2c4bacbbf..2c4348554 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1880,6 +1880,52 @@ TEST_CASE("test scatter") { CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item()); } +TEST_CASE("test is positive infinity") { + array x(1.0f); + CHECK_FALSE(isposinf(x).item()); + + array y(std::numeric_limits::infinity()); + CHECK(isposinf(y).item()); + + array z = identity(7); + CHECK_FALSE(all(isposinf(z)).item()); + + array w = array({1.0f, std::numeric_limits::infinity(), 2.0f}); + CHECK_FALSE(all(isposinf(w)).item()); + + array a(1.0f, bfloat16); + CHECK_FALSE(isposinf(a).item()); + + array b(std::numeric_limits::infinity(), float16); + CHECK(isposinf(b).item()); + + array c(std::numeric_limits::infinity(), bfloat16); + CHECK(isposinf(c).item()); +} + +TEST_CASE("test is negative infinity") { + array x(1.0f); + CHECK_FALSE(isneginf(x).item()); + + array y(-std::numeric_limits::infinity()); + CHECK(isneginf(y).item()); + + array z = identity(7); + CHECK_FALSE(all(isneginf(z)).item()); + + array w = array({1.0f, -std::numeric_limits::infinity(), 2.0f}); + CHECK_FALSE(all(isneginf(w)).item()); + + array a(1.0f, bfloat16); + CHECK_FALSE(isneginf(a).item()); + + array b(-std::numeric_limits::infinity(), float16); + CHECK(isneginf(b).item()); + + array c(-std::numeric_limits::infinity(), bfloat16); + CHECK(isneginf(c).item()); +} + TEST_CASE("test scatter types") { for (auto t : {bool_, uint8, uint16, int8, int16}) { auto in = zeros({4, 4}, t);