From 496315fe1d0206fbbf2fb07b562b094b2283acd1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 5 Jun 2024 14:21:58 -0700 Subject: [PATCH] Fix scan (#1188) * fix scan * improve grid size * fix cpu cummax --- mlx/backend/common/scan.cpp | 2 +- mlx/backend/metal/kernels/scan.h | 1 + mlx/backend/metal/scan.cpp | 14 +++++++------- python/tests/test_ops.py | 14 ++++++++++++++ 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/mlx/backend/common/scan.cpp b/mlx/backend/common/scan.cpp index 221475902..153375aef 100644 --- a/mlx/backend/common/scan.cpp +++ b/mlx/backend/common/scan.cpp @@ -234,7 +234,7 @@ void scan_dispatch( auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; auto init = (issubdtype(input.dtype(), floating)) ? static_cast<U>(-std::numeric_limits<U>::infinity()) - : std::numeric_limits<U>::max(); + : std::numeric_limits<U>::min(); auto opcs = DefaultContiguousScan(op, init); auto opss = DefaultStridedScan(op, init); scan_op(opcs, opss, input, output, axis, reverse, inclusive); diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h index 19e4be86c..67b27ba89 100644 --- a/mlx/backend/metal/kernels/scan.h +++ b/mlx/backend/metal/kernels/scan.h @@ -309,6 +309,7 @@ template < } } } + threadgroup_barrier(mem_flags::mem_threadgroup); // Share the prefix if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index e35ecfbe1..63ca132ff 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -65,16 +65,16 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) { // Compute the thread grid int n_reads = (in.itemsize() <= 4) ? 
4 : 2; - int elements_per_simd = n_reads * 32; + constexpr int simd_size = 32; + int elements_per_simd = n_reads * simd_size; int thread_groups = in.size() / size; int thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (size < n_reads * 1024) { - thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) * - elements_per_simd; - } else if (size < n_reads * 2048) { + if (size <= n_reads * 1024) { thread_group_size = - ((size / 2 + elements_per_simd - 1) / elements_per_simd) * - elements_per_simd; + ((size + elements_per_simd - 1) / elements_per_simd) * simd_size; + } else if (size <= n_reads * 2048) { + thread_group_size = + ((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size; } thread_group_size = std::min( thread_group_size, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index dc89a4afd..c41c79c83 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1678,7 +1678,9 @@ class TestOps(mlx_tests.MLXTestCase): c_mlx = mxop(a_mlx, axis=axis) self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100) for op in ["cumsum", "cumprod", "cummax", "cummin"]: + mxop = getattr(mx, op) c1 = mxop(a_mlx, axis=2) c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False) self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:])) @@ -1719,6 +1721,18 @@ class TestOps(mlx_tests.MLXTestCase): out = mx.cumsum(a_t, axis=-1) expected = (mat_t * a_t[:, None, :]).sum(axis=-1) self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3)) + sizes = [1023, 1024, 1025, 2047, 2048, 2049] + for s in sizes: + a = mx.ones((s,), mx.int32) + out = mx.cumsum(a) + expected = mx.arange(1, s + 1, dtype=mx.int32) + self.assertTrue(mx.array_equal(expected, out)) + + # non-contiguous scan + a = mx.ones((s, 2), mx.int32) + out = mx.cumsum(a, axis=0) + expected = mx.repeat(expected[:, None], 2, axis=1) + self.assertTrue(mx.array_equal(expected, out)) 
def test_squeeze_expand(self): a = mx.zeros((2, 1, 2, 1))