mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make max op NaN propagation rules align with numpy
This commit is contained in:
@@ -187,7 +187,10 @@ struct Max {
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce_impl(T val) {
|
||||
return simd_max(val);
|
||||
if(simd_any(val != val)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd_max(val);
|
||||
}
|
||||
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
@@ -198,7 +201,28 @@ struct Max {
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
if(metal::isnan(a) || metal::isnan(b)) {
|
||||
return static_cast<T>(NAN);
|
||||
} else {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) || metal::isnan(b.imag)) {
|
||||
return static_cast<complex64_t>(NAN);
|
||||
}
|
||||
return a > b ? a : b;
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
@@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") {
|
||||
x = array({true, true, true, false, true, false}, {2, 3});
|
||||
CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());
|
||||
CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());
|
||||
|
||||
x = array({1.0f, NAN, 3.0f});
|
||||
CHECK(isnan(max(x).item<float>()));
|
||||
|
||||
}
|
||||
|
||||
// Test logsumexp
|
||||
|
||||
Reference in New Issue
Block a user