mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add isfinite
(#1318)
* isfinite * remove reduce test since fix is not complete
This commit is contained in:
parent
a098bc92e0
commit
eaaea02010
@ -87,6 +87,38 @@ struct OrReduce {
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct MinReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_out(
|
||||
const array& in,
|
||||
@ -118,15 +150,13 @@ void reduce_dispatch_out(
|
||||
break;
|
||||
}
|
||||
case Reduce::Max: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, op);
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Min: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, op);
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -1378,6 +1378,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
||||
}
|
||||
|
||||
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
|
@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array isinf(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array isfinite(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array isneginf(const array& a, StreamOrDevice s = {});
|
||||
|
@ -1973,6 +1973,27 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are +/- infinity.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isfinite",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
return mlx::core::isfinite(to_array(a), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return a boolean array indicating which elements are finite.
|
||||
|
||||
An element is finite if it is not infinite or NaN.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are finite.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isposinf",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
@ -4254,13 +4275,45 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
||||
[](const nb::object& d1, const nb::object& d2) {
|
||||
auto dispatch_second = [](const auto& t1, const auto& d2) {
|
||||
if (nb::isinstance<Dtype>(d2)) {
|
||||
return issubdtype(t1, nb::cast<Dtype>(d2));
|
||||
} else if (nb::isinstance<Dtype::Category>(d2)) {
|
||||
return issubdtype(t1, nb::cast<Dtype::Category>(d2));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[issubdtype] Received invalid type for second input.");
|
||||
}
|
||||
};
|
||||
if (nb::isinstance<Dtype>(d1)) {
|
||||
return dispatch_second(nb::cast<Dtype>(d1), d2);
|
||||
} else if (nb::isinstance<Dtype::Category>(d1)) {
|
||||
return dispatch_second(nb::cast<Dtype::Category>(d1), d2);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[issubdtype] Received invalid type for first input.");
|
||||
}
|
||||
},
|
||||
""_a,
|
||||
""_a,
|
||||
nb::sig(
|
||||
"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"),
|
||||
R"pbdoc(
|
||||
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||
of another.
|
||||
|
||||
Args:
|
||||
arg1 (Union[Dtype, DtypeCategory]: First dtype or category.
|
||||
arg2 (Union[Dtype, DtypeCategory]: Second dtype or category.
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
A boolean indicating if the first input is a subtype of the
|
||||
second input.
|
||||
|
||||
Example:
|
||||
|
||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||
True
|
||||
@ -4294,22 +4347,6 @@ void init_ops(nb::module_& m) {
|
||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||
False
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"bitwise_and",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
@ -374,6 +374,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
result = mx.isinf(x)
|
||||
self.assertEqual(result.tolist(), [False, False, False])
|
||||
|
||||
def test_isfinite(self):
|
||||
x = mx.array([0.0, float("inf"), float("nan")])
|
||||
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
|
||||
|
||||
x = x.astype(mx.float16)
|
||||
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
|
||||
|
||||
x = x.astype(mx.bfloat16)
|
||||
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
|
||||
|
||||
def test_tri(self):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
|
Loading…
Reference in New Issue
Block a user