mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 13:07:29 +08:00
Add isfinite
(#1318)
* isfinite * remove reduce test since fix is not complete
This commit is contained in:
@@ -1973,6 +1973,27 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are +/- infinity.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isfinite",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
return mlx::core::isfinite(to_array(a), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return a boolean array indicating which elements are finite.
|
||||
|
||||
An element is finite if it is not infinite or NaN.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are finite.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isposinf",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
@@ -4254,62 +4275,78 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
||||
[](const nb::object& d1, const nb::object& d2) {
|
||||
auto dispatch_second = [](const auto& t1, const auto& d2) {
|
||||
if (nb::isinstance<Dtype>(d2)) {
|
||||
return issubdtype(t1, nb::cast<Dtype>(d2));
|
||||
} else if (nb::isinstance<Dtype::Category>(d2)) {
|
||||
return issubdtype(t1, nb::cast<Dtype::Category>(d2));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[issubdtype] Received invalid type for second input.");
|
||||
}
|
||||
};
|
||||
if (nb::isinstance<Dtype>(d1)) {
|
||||
return dispatch_second(nb::cast<Dtype>(d1), d2);
|
||||
} else if (nb::isinstance<Dtype::Category>(d1)) {
|
||||
return dispatch_second(nb::cast<Dtype::Category>(d1), d2);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[issubdtype] Received invalid type for first input.");
|
||||
}
|
||||
},
|
||||
""_a,
|
||||
""_a,
|
||||
nb::sig(
|
||||
"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"),
|
||||
R"pbdoc(
|
||||
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||
of another.
|
||||
|
||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||
True
|
||||
>>> mx.issubdtype(ints.dtype, mx.floating)
|
||||
False
|
||||
Args:
|
||||
arg1 (Union[Dtype, DtypeCategory]: First dtype or category.
|
||||
arg2 (Union[Dtype, DtypeCategory]: Second dtype or category.
|
||||
|
||||
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
||||
>>> mx.issubdtype(floats.dtype, mx.integer)
|
||||
False
|
||||
>>> mx.issubdtype(floats.dtype, mx.floating)
|
||||
True
|
||||
Returns:
|
||||
bool:
|
||||
A boolean indicating if the first input is a subtype of the
|
||||
second input.
|
||||
|
||||
Similar types of different sizes are not subdtypes of each other:
|
||||
Example:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.float32)
|
||||
False
|
||||
>>> mx.issubdtype(mx.float32, mx.float64)
|
||||
False
|
||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||
True
|
||||
>>> mx.issubdtype(ints.dtype, mx.floating)
|
||||
False
|
||||
|
||||
but both are subtypes of `floating`:
|
||||
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
||||
>>> mx.issubdtype(floats.dtype, mx.integer)
|
||||
False
|
||||
>>> mx.issubdtype(floats.dtype, mx.floating)
|
||||
True
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.floating)
|
||||
True
|
||||
>>> mx.issubdtype(mx.float32, mx.floating)
|
||||
True
|
||||
Similar types of different sizes are not subdtypes of each other:
|
||||
|
||||
For convenience, dtype-like objects are allowed too:
|
||||
>>> mx.issubdtype(mx.float64, mx.float32)
|
||||
False
|
||||
>>> mx.issubdtype(mx.float32, mx.float64)
|
||||
False
|
||||
|
||||
>>> mx.issubdtype(mx.float32, mx.inexact)
|
||||
True
|
||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||
False
|
||||
but both are subtypes of `floating`:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.floating)
|
||||
True
|
||||
>>> mx.issubdtype(mx.float32, mx.floating)
|
||||
True
|
||||
|
||||
For convenience, dtype-like objects are allowed too:
|
||||
|
||||
>>> mx.issubdtype(mx.float32, mx.inexact)
|
||||
True
|
||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||
False
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"bitwise_and",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
Reference in New Issue
Block a user