From 8b15773206085120728a939cd07843ef87d36545 Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Tue, 8 Jul 2025 16:41:56 -0700 Subject: [PATCH] Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16. --- mlx/backend/cpu/reduce.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index ce25feb11..87e3aa857 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -325,7 +325,15 @@ struct MaxReduce { }; template - T operator()(simd::Simd x) { + std::enable_if_t, T> operator()(simd::Simd x) { + return simd::max(x); + }; + + template + std::enable_if_t, T> operator()(simd::Simd x) { + if (simd::any(x != x)) { + return static_cast(NAN); + } return simd::max(x); }; }; @@ -527,10 +535,10 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int8: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_min_max(in, out, reduce_type_, axes_);