Make max op NaN propagation rules align with numpy

This commit is contained in:
Joona Havukainen
2025-07-06 14:27:40 -07:00
parent 0e0d9ac522
commit 0d30e9e8ec
2 changed files with 30 additions and 2 deletions

View File

@@ -187,6 +187,9 @@ struct Max {
template <typename T> template <typename T>
T simd_reduce_impl(T val) { T simd_reduce_impl(T val) {
if(simd_any(val != val)) {
return static_cast<T>(NAN);
}
return simd_max(val); return simd_max(val);
} }
@@ -198,7 +201,28 @@ struct Max {
} }
// Operator // Operator
U operator()(U a, U b) { template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
return a > b ? a : b; return a > b ? a : b;
} }
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
if(metal::isnan(a) || metal::isnan(b)) {
return static_cast<T>(NAN);
} else {
return a > b ? a : b;
}
}
template <>
complex64_t operator()(complex64_t a, complex64_t b) {
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) || metal::isnan(b.imag)) {
return static_cast<complex64_t>(NAN);
}
return a > b ? a : b;
}
}; };

View File

@@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") {
x = array({true, true, true, false, true, false}, {2, 3}); x = array({true, true, true, false, true, false}, {2, 3});
CHECK(array_equal(min(x, 1), array({true, false})).item<bool>()); CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());
CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>()); CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());
x = array({1.0f, NAN, 3.0f});
CHECK(isnan(max(x).item<float>()));
} }
// Test logsumexp // Test logsumexp