Add isfinite (#1318)

* isfinite

* remove reduce test since fix is not complete
This commit is contained in:
Awni Hannun 2024-08-13 14:49:28 -07:00 committed by GitHub
parent a098bc92e0
commit eaaea02010
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 129 additions and 46 deletions

View File

@ -87,6 +87,38 @@ struct OrReduce {
}
};
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
};
};
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
};
};
template <typename InT>
void reduce_dispatch_out(
const array& in,
@ -118,15 +150,13 @@ void reduce_dispatch_out(
break;
}
case Reduce::Max: {
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, op);
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, op);
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
}
}

View File

@ -1378,6 +1378,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
return logical_or(isposinf(a, s), isneginf(a, s), s);
}
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);

View File

@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {});
array isinf(const array& a, StreamOrDevice s = {});
array isfinite(const array& a, StreamOrDevice s = {});
array isposinf(const array& a, StreamOrDevice s = {});
array isneginf(const array& a, StreamOrDevice s = {});

View File

@ -1973,6 +1973,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The boolean array indicating which elements are +/- infinity.
)pbdoc");
m.def(
"isfinite",
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::isfinite(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return a boolean array indicating which elements are finite.
An element is finite if it is not infinite or NaN.
Args:
a (array): Input array.
Returns:
array: The boolean array indicating which elements are finite.
)pbdoc");
m.def(
"isposinf",
[](const ScalarOrArray& a, StreamOrDevice s) {
@ -4254,62 +4275,78 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"issubdtype",
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
[](const nb::object& d1, const nb::object& d2) {
auto dispatch_second = [](const auto& t1, const auto& d2) {
if (nb::isinstance<Dtype>(d2)) {
return issubdtype(t1, nb::cast<Dtype>(d2));
} else if (nb::isinstance<Dtype::Category>(d2)) {
return issubdtype(t1, nb::cast<Dtype::Category>(d2));
} else {
throw std::invalid_argument(
"[issubdtype] Received invalid type for second input.");
}
};
if (nb::isinstance<Dtype>(d1)) {
return dispatch_second(nb::cast<Dtype>(d1), d2);
} else if (nb::isinstance<Dtype::Category>(d1)) {
return dispatch_second(nb::cast<Dtype::Category>(d1), d2);
} else {
throw std::invalid_argument(
"[issubdtype] Received invalid type for first input.");
}
},
""_a,
""_a,
nb::sig(
"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"),
R"pbdoc(
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
of another.
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
>>> mx.issubdtype(ints.dtype, mx.integer)
True
>>> mx.issubdtype(ints.dtype, mx.floating)
False
Args:
arg1 (Union[Dtype, DtypeCategory]: First dtype or category.
arg2 (Union[Dtype, DtypeCategory]: Second dtype or category.
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
>>> mx.issubdtype(floats.dtype, mx.integer)
False
>>> mx.issubdtype(floats.dtype, mx.floating)
True
Returns:
bool:
A boolean indicating if the first input is a subtype of the
second input.
Similar types of different sizes are not subdtypes of each other:
Example:
>>> mx.issubdtype(mx.float64, mx.float32)
False
>>> mx.issubdtype(mx.float32, mx.float64)
False
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
>>> mx.issubdtype(ints.dtype, mx.integer)
True
>>> mx.issubdtype(ints.dtype, mx.floating)
False
but both are subtypes of `floating`:
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
>>> mx.issubdtype(floats.dtype, mx.integer)
False
>>> mx.issubdtype(floats.dtype, mx.floating)
True
>>> mx.issubdtype(mx.float64, mx.floating)
True
>>> mx.issubdtype(mx.float32, mx.floating)
True
Similar types of different sizes are not subdtypes of each other:
For convenience, dtype-like objects are allowed too:
>>> mx.issubdtype(mx.float64, mx.float32)
False
>>> mx.issubdtype(mx.float32, mx.float64)
False
>>> mx.issubdtype(mx.float32, mx.inexact)
True
>>> mx.issubdtype(mx.signedinteger, mx.floating)
False
but both are subtypes of `floating`:
>>> mx.issubdtype(mx.float64, mx.floating)
True
>>> mx.issubdtype(mx.float32, mx.floating)
True
For convenience, dtype-like objects are allowed too:
>>> mx.issubdtype(mx.float32, mx.inexact)
True
>>> mx.issubdtype(mx.signedinteger, mx.floating)
False
)pbdoc");
m.def(
"issubdtype",
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
""_a,
""_a);
m.def(
"issubdtype",
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
""_a,
""_a);
m.def(
"issubdtype",
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
&issubdtype),
""_a,
""_a);
m.def(
"bitwise_and",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {

View File

@ -374,6 +374,16 @@ class TestOps(mlx_tests.MLXTestCase):
result = mx.isinf(x)
self.assertEqual(result.tolist(), [False, False, False])
def test_isfinite(self):
x = mx.array([0.0, float("inf"), float("nan")])
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
x = x.astype(mx.float16)
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
x = x.astype(mx.bfloat16)
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])
def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]: