From 1ff2b713b67a519a9c39b983156e0794efd7a117 Mon Sep 17 00:00:00 2001 From: AN Long Date: Tue, 4 Nov 2025 01:51:14 +0900 Subject: [PATCH] Check isnan in maximum / minimum with CPU backend (#2652) * Check isnan in maximum / minimum with CPU backend * Add tests * fix --------- Co-authored-by: Awni Hannun --- mlx/backend/cpu/simd/accelerate_simd.h | 14 ++++++++++---- tests/ops_tests.cpp | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index c89a104a0..f62c67d38 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -217,14 +217,20 @@ Simd atan2(Simd a, Simd b) { template Simd maximum(Simd a, Simd b) { - // TODO add isnan - return asd::max(a.value, b.value); + auto out = Simd(asd::max(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; } template Simd minimum(Simd a, Simd b) { - // TODO add isnan - return asd::min(a.value, b.value); + auto out = Simd(asd::min(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; } template diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c473b59c3..1b9506622 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -4052,3 +4052,24 @@ TEST_CASE("test fp8 conversion") { auto expected = array({-448.0f, 448.0f}); CHECK(array_equal(out, expected, true).item()); } + +TEST_CASE("test max min with nan") { + // Test maximum and minimum with NaN values + auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto max_result = maximum(x, y); + auto min_result = minimum(x, y); + CHECK(array_equal(max_result, expected_max, true).item()); + CHECK(array_equal(min_result, expected_min, true).item()); + + // Test with all NaN values + x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + max_result = maximum(x, y); + min_result = minimum(x, y); + auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + CHECK(array_equal(max_result, expected, true).item()); + CHECK(array_equal(min_result, expected, true).item()); +}