diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index c89a104a0..f62c67d38 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -217,14 +217,20 @@ Simd atan2(Simd a, Simd b) { template Simd maximum(Simd a, Simd b) { - // TODO add isnan - return asd::max(a.value, b.value); + auto out = Simd(asd::max(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; } template Simd minimum(Simd a, Simd b) { - // TODO add isnan - return asd::min(a.value, b.value); + auto out = Simd(asd::min(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; } template diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c473b59c3..1b9506622 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -4052,3 +4052,24 @@ TEST_CASE("test fp8 conversion") { auto expected = array({-448.0f, 448.0f}); CHECK(array_equal(out, expected, true).item()); } + +TEST_CASE("test max min with nan") { + // Test maximum and minimum with NaN values + auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto max_result = maximum(x, y); + auto min_result = minimum(x, y); + CHECK(array_equal(max_result, expected_max, true).item()); + CHECK(array_equal(min_result, expected_min, true).item()); + + // Test with all NaN values + x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + max_result = maximum(x, y); + min_result = minimum(x, y); + auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + CHECK(array_equal(max_result, expected, true).item()); + CHECK(array_equal(min_result, expected, true).item()); +}