From 8b9a3f3ceab86a6e79dde498b20c0142328bf4a4 Mon Sep 17 00:00:00 2001 From: jhavukainen <104022140+jhavukainen@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:26:27 -0700 Subject: [PATCH] Align mlx::core::max op nan propagation with NumPy (#2339) * Make max op NaN propagation rules align with numpy * Adding benchmarks and testing for max op nanpropagation * Pre-commit formatting * Fix max complex64 nan propagation and add test * Improve the cpp unittest * Only check nans on non-integral types in simd_reduce_impl. * Cleanup using namespace alias * Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16. * Make the max nanpropagation test more meaningful for integer types * Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR. --- benchmarks/cpp/single_ops.cpp | 11 +++++ benchmarks/python/single_ops.py | 8 ++++ mlx/backend/cpu/reduce.cpp | 14 ++++-- mlx/backend/metal/kernels/reduction/ops.h | 40 +++++++++++++++- python/tests/cuda_skip.py | 2 + python/tests/test_reduce.py | 57 +++++++++++++++++++++++ tests/ops_tests.cpp | 4 ++ 7 files changed, 131 insertions(+), 5 deletions(-) diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp index 5b327be58..6eac366bc 100644 --- a/benchmarks/cpp/single_ops.cpp +++ b/benchmarks/cpp/single_ops.cpp @@ -192,6 +192,17 @@ void time_reductions() { auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; TIME(argmin_along_1); + + auto indices = mx::array({1}); + auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1}); + std::vector axes{0}; + auto b = scatter(a, {indices}, updates, axes); + mx::eval(b); + + auto max_along_0 = [&b]() { return mx::max(b, 0, false); }; + TIME(max_along_0); + auto max_along_1 = [&b]() { return mx::max(b, 1, false); }; + TIME(max_along_1); } void time_gather_scatter() { diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py index 3160a1833..5d2906fe7 100644 --- a/benchmarks/python/single_ops.py +++ b/benchmarks/python/single_ops.py @@ -51,6 +51,13 @@ def time_maximum(): time_fn(mx.maximum, a, b) +def time_max(): + a = mx.random.uniform(shape=(32, 1024, 1024)) + a[1, 1] = mx.nan + mx.eval(a) + time_fn(mx.max, a, 0) + + def time_negative(): a = mx.random.uniform(shape=(10000, 1000)) mx.eval(a) @@ -108,6 +115,7 @@ if __name__ == "__main__": time_add() time_matmul() + time_max() time_maximum() time_exp() time_negative() diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index ce25feb11..87e3aa857 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -325,7 +325,15 @@ struct MaxReduce { }; template - T operator()(simd::Simd x) { + std::enable_if_t, T> operator()(simd::Simd x) { + return simd::max(x); + }; + + template + std::enable_if_t, T> operator()(simd::Simd x) { + if (simd::any(x != x)) { + return static_cast(NAN); + } return simd::max(x); }; }; @@ -527,10 +535,10 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int8: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index 68ed11986..57ddffef8 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -186,7 +186,15 @@ struct Max { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_max(val); } @@ -198,7 +206,35 @@ struct Max { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a > b ? a : b; } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } }; diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 17eb80eee..afd48bd03 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -3,6 +3,8 @@ cuda_skip = { "TestLayers.test_quantized_embedding", "TestOps.test_dynamic_slicing", "TestReduce.test_dtypes", + "TestReduce.test_nanpropagation", + "TestReduce.test_nanpropagation_complex64", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 2b899c099..d757f1527 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -153,6 +153,63 @@ class TestReduce(mlx_tests.MLXTestCase): x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9) check(x, (1, 3, 5, 7, 9)) + def test_nanpropagation(self): + dtypes = [ + "uint8", + "uint16", + "uint32", + "int8", + "int16", + "int32", + "float16", + "float32", + ] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype)) + indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2) + for idx in indices: + x[idx[0], idx[1]] = mx.nan + x_np = np.array(x) + + for op in ["max"]: + for axis in [0, 1]: + out = getattr(mx, op)(x, axis=axis) + ref = getattr(np, op)(x_np, axis=axis) + self.assertTrue(np.array_equal(out, ref, equal_nan=True)) + + def test_nanpropagation_complex64(self): + complex_array_1 = mx.array( + [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_2 = mx.array( + [1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_3 = mx.array( + [1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_4 = mx.array( + [mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + + np_arrays = [ + np.array(complex_array_1), + np.array(complex_array_2), + np.array(complex_array_3), + np.array(complex_array_4), + ] + + for mx_arr, np_arr in zip( + [complex_array_1, complex_array_2, complex_array_3, complex_array_4], + np_arrays, + ): + for axis in [0, 1]: + for op in ["max"]: + out = getattr(mx, op)(mx_arr, axis=axis) + ref = getattr(np, op)(np_arr, axis=axis) + self.assertTrue(np.array_equal(out, ref, equal_nan=True)) + if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 8833424a6..1a9781c7c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") { x = array({true, true, true, false, true, false}, {2, 3}); CHECK(array_equal(min(x, 1), array({true, false})).item()); CHECK(array_equal(min(x, 0), array({false, true, false})).item()); + + x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item()); + CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item()); } // Test logsumexp