Add isfinite (#1318)

* isfinite

* remove reduce test since fix is not complete
This commit is contained in:
Awni Hannun
2024-08-13 14:49:28 -07:00
committed by GitHub
parent a098bc92e0
commit eaaea02010
5 changed files with 129 additions and 46 deletions

View File

@@ -1973,6 +1973,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The boolean array indicating which elements are +/- infinity.
)pbdoc");
m.def(
"isfinite",
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::isfinite(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return a boolean array indicating which elements are finite.
An element is finite if it is not infinite or NaN.
Args:
a (array): Input array.
Returns:
array: The boolean array indicating which elements are finite.
)pbdoc");
m.def(
"isposinf",
[](const ScalarOrArray& a, StreamOrDevice s) {
@@ -4254,62 +4275,78 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"issubdtype",
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
[](const nb::object& d1, const nb::object& d2) {
auto dispatch_second = [](const auto& t1, const auto& d2) {
if (nb::isinstance<Dtype>(d2)) {
return issubdtype(t1, nb::cast<Dtype>(d2));
} else if (nb::isinstance<Dtype::Category>(d2)) {
return issubdtype(t1, nb::cast<Dtype::Category>(d2));
} else {
throw std::invalid_argument(
"[issubdtype] Received invalid type for second input.");
}
};
if (nb::isinstance<Dtype>(d1)) {
return dispatch_second(nb::cast<Dtype>(d1), d2);
} else if (nb::isinstance<Dtype::Category>(d1)) {
return dispatch_second(nb::cast<Dtype::Category>(d1), d2);
} else {
throw std::invalid_argument(
"[issubdtype] Received invalid type for first input.");
}
},
""_a,
""_a,
nb::sig(
"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"),
R"pbdoc(
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
of another.
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
>>> mx.issubdtype(ints.dtype, mx.integer)
True
>>> mx.issubdtype(ints.dtype, mx.floating)
False
Args:
arg1 (Union[Dtype, DtypeCategory]: First dtype or category.
arg2 (Union[Dtype, DtypeCategory]: Second dtype or category.
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
>>> mx.issubdtype(floats.dtype, mx.integer)
False
>>> mx.issubdtype(floats.dtype, mx.floating)
True
Returns:
bool:
A boolean indicating if the first input is a subtype of the
second input.
Similar types of different sizes are not subdtypes of each other:
Example:
>>> mx.issubdtype(mx.float64, mx.float32)
False
>>> mx.issubdtype(mx.float32, mx.float64)
False
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
>>> mx.issubdtype(ints.dtype, mx.integer)
True
>>> mx.issubdtype(ints.dtype, mx.floating)
False
but both are subtypes of `floating`:
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
>>> mx.issubdtype(floats.dtype, mx.integer)
False
>>> mx.issubdtype(floats.dtype, mx.floating)
True
>>> mx.issubdtype(mx.float64, mx.floating)
True
>>> mx.issubdtype(mx.float32, mx.floating)
True
Similar types of different sizes are not subdtypes of each other:
For convenience, dtype-like objects are allowed too:
>>> mx.issubdtype(mx.float64, mx.float32)
False
>>> mx.issubdtype(mx.float32, mx.float64)
False
>>> mx.issubdtype(mx.float32, mx.inexact)
True
>>> mx.issubdtype(mx.signedinteger, mx.floating)
False
but both are subtypes of `floating`:
>>> mx.issubdtype(mx.float64, mx.floating)
True
>>> mx.issubdtype(mx.float32, mx.floating)
True
For convenience, dtype-like objects are allowed too:
>>> mx.issubdtype(mx.float32, mx.inexact)
True
>>> mx.issubdtype(mx.signedinteger, mx.floating)
False
)pbdoc");
m.def(
"issubdtype",
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
""_a,
""_a);
m.def(
"issubdtype",
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
""_a,
""_a);
m.def(
"issubdtype",
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
&issubdtype),
""_a,
""_a);
m.def(
"bitwise_and",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {