Compare commits

...

2 Commits

Author SHA1 Message Date
Joona Havukainen
3e885f583a Cleanup using namespace alias 2025-07-07 18:25:57 -07:00
Joona Havukainen
c7af3016eb Only check nans on non-integral types in simd_reduce_impl. 2025-07-07 18:24:30 -07:00
2 changed files with 8 additions and 3 deletions

View File

@@ -193,8 +193,8 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mlx::core::array({1});
auto updates = mlx::core::reshape(mlx::core::array({NAN}), {1, 1, 1});
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);

View File

@@ -186,7 +186,12 @@ struct Max {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce_impl(T val) {
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
return simd_max(val);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
if (simd_any(val != val)) {
return static_cast<T>(NAN);
}