Align mlx::core::min op nan propagation with NumPy (#2346)

This commit is contained in:
jhavukainen
2025-07-10 06:20:43 -07:00
committed by GitHub
parent 85873cb162
commit 8c7bc30ce4
5 changed files with 62 additions and 6 deletions

View File

@@ -350,7 +350,15 @@ struct MinReduce {
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
return simd::min(x);
};
};