Make max op NaN propagation rules align with numpy

This commit is contained in:
Joona Havukainen
2025-07-06 14:27:40 -07:00
parent 0e0d9ac522
commit 0d30e9e8ec
2 changed files with 30 additions and 2 deletions

View File

@@ -187,6 +187,9 @@ struct Max {
template <typename T>
T simd_reduce_impl(T val) {
if(simd_any(val != val)) {
return static_cast<T>(NAN);
}
return simd_max(val);
}
@@ -198,7 +201,28 @@ struct Max {
}
// Operator
U operator()(U a, U b) {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
return a > b ? a : b;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
if(metal::isnan(a) || metal::isnan(b)) {
return static_cast<T>(NAN);
} else {
return a > b ? a : b;
}
}
template <>
complex64_t operator()(complex64_t a, complex64_t b) {
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) || metal::isnan(b.imag)) {
return static_cast<complex64_t>(NAN);
}
return a > b ? a : b;
}
};

View File

@@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") {
x = array({true, true, true, false, true, false}, {2, 3});
CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());
CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());
x = array({1.0f, NAN, 3.0f});
CHECK(isnan(max(x).item<float>()));
}
// Test logsumexp