mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Add isfinite
(#1318)
* isfinite * remove reduce test since fix is not complete
This commit is contained in:
parent
a098bc92e0
commit
eaaea02010
@ -87,6 +87,38 @@ struct OrReduce {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MaxReduce {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||||
|
(*y) = (*y > x) ? *y : x;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
*y = x;
|
||||||
|
} else {
|
||||||
|
(*y) = (*y > x) ? *y : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MinReduce {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||||
|
(*y) = (*y < x) ? *y : x;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
*y = x;
|
||||||
|
} else {
|
||||||
|
(*y) = (*y < x) ? *y : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
template <typename InT>
|
template <typename InT>
|
||||||
void reduce_dispatch_out(
|
void reduce_dispatch_out(
|
||||||
const array& in,
|
const array& in,
|
||||||
@ -118,15 +150,13 @@ void reduce_dispatch_out(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Reduce::Max: {
|
case Reduce::Max: {
|
||||||
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
|
|
||||||
auto init = Limits<InT>::min;
|
auto init = Limits<InT>::min;
|
||||||
reduction_op<InT, InT>(in, out, axes, init, op);
|
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Reduce::Min: {
|
case Reduce::Min: {
|
||||||
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
|
|
||||||
auto init = Limits<InT>::max;
|
auto init = Limits<InT>::max;
|
||||||
reduction_op<InT, InT>(in, out, axes, init, op);
|
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1378,6 +1378,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
|
||||||
|
}
|
||||||
|
|
||||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||||
return full(a.shape(), false, bool_, s);
|
return full(a.shape(), false, bool_, s);
|
||||||
|
@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {});
|
|||||||
|
|
||||||
array isinf(const array& a, StreamOrDevice s = {});
|
array isinf(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array isfinite(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
array isposinf(const array& a, StreamOrDevice s = {});
|
array isposinf(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
array isneginf(const array& a, StreamOrDevice s = {});
|
array isneginf(const array& a, StreamOrDevice s = {});
|
||||||
|
@ -1973,6 +1973,27 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The boolean array indicating which elements are +/- infinity.
|
array: The boolean array indicating which elements are +/- infinity.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"isfinite",
|
||||||
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::isfinite(to_array(a), s);
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Return a boolean array indicating which elements are finite.
|
||||||
|
|
||||||
|
An element is finite if it is not infinite or NaN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The boolean array indicating which elements are finite.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"isposinf",
|
"isposinf",
|
||||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
@ -4254,13 +4275,45 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"issubdtype",
|
"issubdtype",
|
||||||
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
[](const nb::object& d1, const nb::object& d2) {
|
||||||
|
auto dispatch_second = [](const auto& t1, const auto& d2) {
|
||||||
|
if (nb::isinstance<Dtype>(d2)) {
|
||||||
|
return issubdtype(t1, nb::cast<Dtype>(d2));
|
||||||
|
} else if (nb::isinstance<Dtype::Category>(d2)) {
|
||||||
|
return issubdtype(t1, nb::cast<Dtype::Category>(d2));
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[issubdtype] Received invalid type for second input.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if (nb::isinstance<Dtype>(d1)) {
|
||||||
|
return dispatch_second(nb::cast<Dtype>(d1), d2);
|
||||||
|
} else if (nb::isinstance<Dtype::Category>(d1)) {
|
||||||
|
return dispatch_second(nb::cast<Dtype::Category>(d1), d2);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[issubdtype] Received invalid type for first input.");
|
||||||
|
}
|
||||||
|
},
|
||||||
""_a,
|
""_a,
|
||||||
""_a,
|
""_a,
|
||||||
|
nb::sig(
|
||||||
|
"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||||
of another.
|
of another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg1 (Union[Dtype, DtypeCategory]: First dtype or category.
|
||||||
|
arg2 (Union[Dtype, DtypeCategory]: Second dtype or category.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool:
|
||||||
|
A boolean indicating if the first input is a subtype of the
|
||||||
|
second input.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||||
True
|
True
|
||||||
@ -4294,22 +4347,6 @@ void init_ops(nb::module_& m) {
|
|||||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||||
False
|
False
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
|
||||||
"issubdtype",
|
|
||||||
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
|
||||||
""_a,
|
|
||||||
""_a);
|
|
||||||
m.def(
|
|
||||||
"issubdtype",
|
|
||||||
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
|
||||||
""_a,
|
|
||||||
""_a);
|
|
||||||
m.def(
|
|
||||||
"issubdtype",
|
|
||||||
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
|
||||||
&issubdtype),
|
|
||||||
""_a,
|
|
||||||
""_a);
|
|
||||||
m.def(
|
m.def(
|
||||||
"bitwise_and",
|
"bitwise_and",
|
||||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||||
|
@ -374,6 +374,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
result = mx.isinf(x)
|
result = mx.isinf(x)
|
||||||
self.assertEqual(result.tolist(), [False, False, False])
|
self.assertEqual(result.tolist(), [False, False, False])
|
||||||
|
|
||||||
|
def test_isfinite(self):
|
||||||
|
x = mx.array([0.0, float("inf"), float("nan")])
|
||||||
|
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
|
||||||
|
|
||||||
|
x = x.astype(mx.float16)
|
||||||
|
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
|
||||||
|
|
||||||
|
x = x.astype(mx.bfloat16)
|
||||||
|
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
|
||||||
|
|
||||||
def test_tri(self):
|
def test_tri(self):
|
||||||
for shape in [[4], [4, 4], [2, 10]]:
|
for shape in [[4], [4, 4], [2, 10]]:
|
||||||
for diag in [-1, 0, 1, -2]:
|
for diag in [-1, 0, 1, -2]:
|
||||||
|
Loading…
Reference in New Issue
Block a user