diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py
index 24e609d92..5d2906fe7 100644
--- a/benchmarks/python/single_ops.py
+++ b/benchmarks/python/single_ops.py
@@ -50,12 +50,14 @@ def time_maximum():
     mx.eval(a, b)
     time_fn(mx.maximum, a, b)
 
+
 def time_max():
     a = mx.random.uniform(shape=(32, 1024, 1024))
-    a[1,1] = mx.nan
+    a[1, 1] = mx.nan
     mx.eval(a)
     time_fn(mx.max, a, 0)
 
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h
index d2083c5bc..a3f3202fa 100644
--- a/mlx/backend/metal/kernels/reduction/ops.h
+++ b/mlx/backend/metal/kernels/reduction/ops.h
@@ -187,10 +187,10 @@ struct Max {
 
   template <typename T>
   T simd_reduce_impl(T val) {
-    if(simd_any(val != val)) {
+    if (simd_any(val != val)) {
       return static_cast<T>(NAN);
     }
-      return simd_max(val);
+    return simd_max(val);
   }
 
   static constexpr constant U init = Limits<U>::min;
@@ -208,19 +208,18 @@ struct Max {
 
   template <typename T>
   metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
-    if(metal::isnan(a) || metal::isnan(b)) {
+    if (metal::isnan(a) || metal::isnan(b)) {
       return static_cast<T>(NAN);
     } else {
       return a > b ? a : b;
     }
   }
 
-  template <>
   complex64_t operator()(complex64_t a, complex64_t b) {
-    if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) || metal::isnan(b.imag)) {
+    if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) ||
+        metal::isnan(b.imag)) {
       return static_cast<complex64_t>(NAN);
-      }
+    }
     return a > b ? a : b;
-    }
-  }
+  }
 };
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index 1617257d7..6b2ff06b8 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -168,7 +168,7 @@ class TestReduce(mlx_tests.MLXTestCase):
         for dtype in dtypes:
             with self.subTest(dtype=dtype):
                 x = (mx.random.normal((4, 4))).astype(getattr(mx, dtype))
-                indices = mx.random.randint(0, 4, shape=(6,)).reshape(3,2)
+                indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)
                 for idx in indices:
                     x[*idx] = mx.nan
                 x_np = np.array(x)
@@ -179,5 +179,6 @@ class TestReduce(mlx_tests.MLXTestCase):
                 ref = getattr(np, op)(x_np, axis=axis)
                 self.assertTrue(np.array_equal(out, ref, equal_nan=True))
 
+
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 156d50838..af3e3a698 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1027,7 +1027,6 @@ TEST_CASE("test reduction ops") {
 
   x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f});
  CHECK(isnan(max(x).item<float>()));
-
 }
 
 // Test logsumexp