From b219d12a6bac637b27286bd0230babe40b0741c2 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 20 Mar 2024 11:37:58 -0700 Subject: [PATCH] Check edge case handling in row reduce med kernel (#858) --- .../kernels/reduction/kernels/reduce_row.metal | 14 ++++++++------ python/tests/test_ops.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal index 1eac8de48..499e791e2 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal @@ -108,15 +108,17 @@ template const short i_ed = short(reduction_size); const short i_jump = reductions_per_thread; - for(short r = r_st; r < r_ed; r += r_jump) { + if(r_st < r_jump) { + for(short r = r_st; r < r_ed; r += r_jump) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T * in_row = in + in_idx; + uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); + const device T * in_row = in + in_idx; + + for(short i = i_st; i < i_ed; i += i_jump) { + total_val = op(static_cast(in_row[i]), total_val); + } - for(short i = i_st; i < i_ed; i += i_jump) { - total_val = op(static_cast(in_row[i]), total_val); } - } } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index cbb49e8ae..5e599c01d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -590,6 +590,18 @@ class TestOps(mlx_tests.MLXTestCase): self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape)) self.assertTrue(np.all(sum_npy == sum_mlx)) + x_npy = ( + np.arange(3 * 2 * 3 * 3 * 3 * 3) + .reshape(3, 2, 3, 3, 3, 3) + .astype(np.float32) + ) + x_mlx = mx.array(x_npy) + + y_mlx = x_mlx.sum(axis=(0, 1, 3, 4, 5)) + y_npy = x_npy.sum(axis=(0, 1, 3, 4, 5)) + + self.assertTrue(np.array_equal(y_mlx, y_npy)) + def test_prod(self): x = mx.array( [