Add isfinite (#1318)

* isfinite

* remove reduce test since fix is not complete
This commit is contained in:
Awni Hannun
2024-08-13 14:49:28 -07:00
committed by GitHub
parent a098bc92e0
commit eaaea02010
5 changed files with 129 additions and 46 deletions

View File

@@ -87,6 +87,38 @@ struct OrReduce {
}
};
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
};
};
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
};
};
template <typename InT>
void reduce_dispatch_out(
const array& in,
@@ -118,15 +150,13 @@ void reduce_dispatch_out(
break;
}
case Reduce::Max: {
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, op);
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, op);
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
}
}

View File

@@ -1378,6 +1378,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
return logical_or(isposinf(a, s), isneginf(a, s), s);
}
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);

View File

@@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {});
array isinf(const array& a, StreamOrDevice s = {});
array isfinite(const array& a, StreamOrDevice s = {});
array isposinf(const array& a, StreamOrDevice s = {});
array isneginf(const array& a, StreamOrDevice s = {});