diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 070262c68..87651d47e 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -87,6 +87,38 @@ struct OrReduce { } }; +struct MaxReduce { + template + std::enable_if_t> operator()(T* y, T x) { + (*y) = (*y > x) ? *y : x; + }; + + template + std::enable_if_t> operator()(T* y, T x) { + if (std::isnan(x)) { + *y = x; + } else { + (*y) = (*y > x) ? *y : x; + } + }; +}; + +struct MinReduce { + template + std::enable_if_t> operator()(T* y, T x) { + (*y) = (*y < x) ? *y : x; + }; + + template + std::enable_if_t> operator()(T* y, T x) { + if (std::isnan(x)) { + *y = x; + } else { + (*y) = (*y < x) ? *y : x; + } + }; +}; + template void reduce_dispatch_out( const array& in, @@ -118,15 +150,13 @@ void reduce_dispatch_out( break; } case Reduce::Max: { - auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; }; auto init = Limits::min; - reduction_op(in, out, axes, init, op); + reduction_op(in, out, axes, init, MaxReduce()); break; } case Reduce::Min: { - auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; }; auto init = Limits::max; - reduction_op(in, out, axes, init, op); + reduction_op(in, out, axes, init, MinReduce()); break; } } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index aecf3a65c..9ffa693ad 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1378,6 +1378,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) { return logical_or(isposinf(a, s), isneginf(a, s), s); } +array isfinite(const array& a, StreamOrDevice s /* = {} */) { + return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s); +} + array isposinf(const array& a, StreamOrDevice s /* = {} */) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); diff --git a/mlx/ops.h b/mlx/ops.h index 2a9c2f961..daff9bcdc 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {}); array isinf(const array& a, StreamOrDevice s = {}); +array isfinite(const array& a, StreamOrDevice s = {}); + array isposinf(const array& a, StreamOrDevice s = {}); array isneginf(const array& a, StreamOrDevice s = {}); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index e21bf525c..3e4aa1093 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1973,6 +1973,27 @@ void init_ops(nb::module_& m) { Returns: array: The boolean array indicating which elements are +/- infinity. )pbdoc"); + m.def( + "isfinite", + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::isfinite(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return a boolean array indicating which elements are finite. + + An element is finite if it is not infinite or NaN. + + Args: + a (array): Input array. + + Returns: + array: The boolean array indicating which elements are finite. + )pbdoc"); m.def( "isposinf", [](const ScalarOrArray& a, StreamOrDevice s) { @@ -4254,62 +4275,78 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "issubdtype", - nb::overload_cast(&issubdtype), + [](const nb::object& d1, const nb::object& d2) { + auto dispatch_second = [](const auto& t1, const auto& d2) { + if (nb::isinstance(d2)) { + return issubdtype(t1, nb::cast(d2)); + } else if (nb::isinstance(d2)) { + return issubdtype(t1, nb::cast(d2)); + } else { + throw std::invalid_argument( + "[issubdtype] Received invalid type for second input."); + } + }; + if (nb::isinstance(d1)) { + return dispatch_second(nb::cast(d1), d2); + } else if (nb::isinstance(d1)) { + return dispatch_second(nb::cast(d1), d2); + } else { + throw std::invalid_argument( + "[issubdtype] Received invalid type for first input."); + } + }, ""_a, ""_a, + nb::sig( + "def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"), R"pbdoc( Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype of another. - >>> ints = mx.array([1, 2, 3], dtype=mx.int32) - >>> mx.issubdtype(ints.dtype, mx.integer) - True - >>> mx.issubdtype(ints.dtype, mx.floating) - False + Args: + arg1 (Union[Dtype, DtypeCategory]: First dtype or category. + arg2 (Union[Dtype, DtypeCategory]: Second dtype or category. - >>> floats = mx.array([1, 2, 3], dtype=mx.float32) - >>> mx.issubdtype(floats.dtype, mx.integer) - False - >>> mx.issubdtype(floats.dtype, mx.floating) - True + Returns: + bool: + A boolean indicating if the first input is a subtype of the + second input. - Similar types of different sizes are not subdtypes of each other: + Example: - >>> mx.issubdtype(mx.float64, mx.float32) - False - >>> mx.issubdtype(mx.float32, mx.float64) - False + >>> ints = mx.array([1, 2, 3], dtype=mx.int32) + >>> mx.issubdtype(ints.dtype, mx.integer) + True + >>> mx.issubdtype(ints.dtype, mx.floating) + False - but both are subtypes of `floating`: + >>> floats = mx.array([1, 2, 3], dtype=mx.float32) + >>> mx.issubdtype(floats.dtype, mx.integer) + False + >>> mx.issubdtype(floats.dtype, mx.floating) + True - >>> mx.issubdtype(mx.float64, mx.floating) - True - >>> mx.issubdtype(mx.float32, mx.floating) - True + Similar types of different sizes are not subdtypes of each other: - For convenience, dtype-like objects are allowed too: + >>> mx.issubdtype(mx.float64, mx.float32) + False + >>> mx.issubdtype(mx.float32, mx.float64) + False - >>> mx.issubdtype(mx.float32, mx.inexact) - True - >>> mx.issubdtype(mx.signedinteger, mx.floating) - False + but both are subtypes of `floating`: + + >>> mx.issubdtype(mx.float64, mx.floating) + True + >>> mx.issubdtype(mx.float32, mx.floating) + True + + For convenience, dtype-like objects are allowed too: + + >>> mx.issubdtype(mx.float32, mx.inexact) + True + >>> mx.issubdtype(mx.signedinteger, mx.floating) + False )pbdoc"); - m.def( - "issubdtype", - nb::overload_cast(&issubdtype), - ""_a, - ""_a); - m.def( - "issubdtype", - nb::overload_cast(&issubdtype), - ""_a, - ""_a); - m.def( - "issubdtype", - nb::overload_cast( - &issubdtype), - ""_a, - ""_a); m.def( "bitwise_and", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2b8cdfbb8..c17190126 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -374,6 +374,16 @@ class TestOps(mlx_tests.MLXTestCase): result = mx.isinf(x) self.assertEqual(result.tolist(), [False, False, False]) + def test_isfinite(self): + x = mx.array([0.0, float("inf"), float("nan")]) + self.assertEqual(mx.isfinite(x).tolist(), [True, False, False]) + + x = x.astype(mx.float16) + self.assertEqual(mx.isfinite(x).tolist(), [True, False, False]) + + x = x.astype(mx.bfloat16) + self.assertEqual(mx.isfinite(x).tolist(), [True, False, False]) + def test_tri(self): for shape in [[4], [4, 4], [2, 10]]: for diag in [-1, 0, 1, -2]: