diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp
index 6eac366bc..1f93a78d7 100644
--- a/benchmarks/cpp/single_ops.cpp
+++ b/benchmarks/cpp/single_ops.cpp
@@ -203,6 +203,11 @@ void time_reductions() {
   TIME(max_along_0);
   auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
   TIME(max_along_1);
+
+  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
+  TIME(min_along_0);
+  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
+  TIME(min_along_1);
 }
 
 void time_gather_scatter() {
diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py
index 5d2906fe7..939faf305 100644
--- a/benchmarks/python/single_ops.py
+++ b/benchmarks/python/single_ops.py
@@ -58,6 +58,13 @@ def time_max():
     time_fn(mx.max, a, 0)
 
 
+def time_min():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.min, a, 0)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -115,6 +122,7 @@ if __name__ == "__main__":
 
     time_add()
     time_matmul()
+    time_min()
     time_max()
     time_maximum()
     time_exp()
diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp
index 87e3aa857..8febbd050 100644
--- a/mlx/backend/cpu/reduce.cpp
+++ b/mlx/backend/cpu/reduce.cpp
@@ -350,7 +350,15 @@ struct MinReduce {
   };
 
   template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
+  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    return simd::min(x);
+  };
+
+  template <int N, typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    if (simd::any(x != x)) {
+      return static_cast<T>(NAN);
+    }
     return simd::min(x);
   };
 };
diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h
index 57ddffef8..11d8e83ac 100644
--- a/mlx/backend/metal/kernels/reduction/ops.h
+++ b/mlx/backend/metal/kernels/reduction/ops.h
@@ -164,7 +164,15 @@ struct Min {
   DEFINE_SIMD_REDUCE()
 
   template <typename T>
-  T simd_reduce_impl(T val) {
+  metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
+    return simd_min(val);
+  }
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
+    if (simd_any(val != val)) {
+      return static_cast<T>(NAN);
+    }
     return simd_min(val);
   }
 
@@ -176,11 +184,38 @@ struct Min {
   }
 
   // Operator
-  U operator()(U a, U b) {
+  template <typename T>
+  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
     return a < b ? a : b;
   }
-};
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
+    if (metal::isnan(a) || metal::isnan(b)) {
+      return static_cast<T>(NAN);
+    } else {
+      return a < b ? a : b;
+    }
+  }
+
+  template <>
+  complex64_t operator()(complex64_t a, complex64_t b) {
+    bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
+    bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
+
+    if (!real_is_nan && !imag_is_nan) {
+      return a < b ? a : b;
+    } else if (real_is_nan && !imag_is_nan) {
+      return complex64_t(
+          static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
+    } else if (!real_is_nan && imag_is_nan) {
+      return complex64_t(
+          a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
+    } else {
+      return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
+    }
+  };
+};
 
 template <typename U>
 struct Max {
   DEFINE_SIMD_REDUCE()
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index d757f1527..9efd6c5c7 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase):
             x[idx[0], idx[1]] = mx.nan
         x_np = np.array(x)
 
-        for op in ["max"]:
+        for op in ["max", "min"]:
             for axis in [0, 1]:
                 out = getattr(mx, op)(x, axis=axis)
                 ref = getattr(np, op)(x_np, axis=axis)
@@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase):
             np_arrays,
         ):
             for axis in [0, 1]:
-                for op in ["max"]:
+                for op in ["max", "min"]:
                     out = getattr(mx, op)(mx_arr, axis=axis)
                     ref = getattr(np, op)(np_arr, axis=axis)
                     self.assertTrue(np.array_equal(out, ref, equal_nan=True))
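
The intended semantics, mirrored from the test changes above: mx.min now propagates NaN the way np.min does, so any NaN in the reduced slice makes the result NaN. A minimal sketch of that behavior (not part of the patch; assumes mlx and numpy are installed, and the example array is illustrative):

# Sketch: after this change, mx.min agrees with np.min on NaN inputs.
import mlx.core as mx
import numpy as np

x = mx.array([[1.0, 2.0], [float("nan"), 4.0]])
x_np = np.array(x)

for axis in [0, 1]:
    out = mx.min(x, axis=axis)   # NaN in the reduced slice yields NaN
    ref = np.min(x_np, axis=axis)
    assert np.array_equal(np.array(out), ref, equal_nan=True)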