diff --git a/mlx/backend/metal/kernels/reduction/reduce_col.h b/mlx/backend/metal/kernels/reduction/reduce_col.h
index 2d102911a..f92562b84 100644
--- a/mlx/backend/metal/kernels/reduction/reduce_col.h
+++ b/mlx/backend/metal/kernels/reduction/reduce_col.h
@@ -28,10 +28,8 @@ template <
   looped_elem_to_loc<NDIMS> loop;
   const device T* row;
-  // Case 1:
-  // reduction_stride is small, reduction_size is small and non_col_reductions
-  // is small. Each thread computes reduction_stride outputs.
-  if (reduction_size * non_col_reductions < 64) {
+  // Case 1: Small row small column
+  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
     U totals[31];
     for (int i = 0; i < 31; i++) {
       totals[i] = Op::init;
     }
@@ -71,10 +69,55 @@ template <
     }
   }
 
-  // Case 2:
-  // Reduction stride is small but everything else can be big. We loop both
-  // across reduction size and non_col_reductions. Each simdgroup produces
-  // N_READS outputs.
+  // Case 2: Long row small column
+  else if (reduction_size * non_col_reductions < 32) {
+    U totals[N_READS];
+    for (int i = 0; i < N_READS; i++) {
+      totals[i] = Op::init;
+    }
+
+    short size = reduction_size;
+    size_t offset = size_t(tid.x) * N_READS;
+    bool safe = offset + N_READS <= reduction_stride;
+    short extra = reduction_stride - offset;
+
+    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
+    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
+
+    for (uint r = 0; r < non_col_reductions; r++) {
+      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+
+      if (safe) {
+        for (short i = 0; i < size; i++) {
+          for (short j = 0; j < N_READS; j++) {
+            totals[j] =
+                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
+          }
+        }
+      } else {
+        for (short i = 0; i < size; i++) {
+          for (short j = 0; j < extra; j++) {
+            totals[j] =
+                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
+          }
+        }
+      }
+
+      loop.next(reduce_shape, reduce_strides);
+    }
+    out += out_idx * reduction_stride + offset;
+    if (safe) {
+      for (short i = 0; i < N_READS; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (short i = 0; i < extra; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+
+  // Case 3: Long row medium column
   else {
     threadgroup U shared_vals[1024];
     U totals[N_READS];
@@ -147,17 +190,13 @@ template <
 /**
  * Our approach is the following simple looped approach:
  *  1. Each thread keeps running totals for BN / n_simdgroups outputs.
- *  2. Load a tile BM, BN in shared memory.
- *  3. Add the values from shared memory to the current running totals.
- *     Neighboring threads access different rows (transposed acces).
- *  4. Move ahead to the next tile until the M axis is exhausted.
- *  5. Move ahead to the next non column reduction
- *  6. Simd reduce the running totals
+ *  2. Load a tile BM, BN in registers and accumulate in the running totals
+ *  3. Move ahead by BM steps until the column axis and the non column
+ *     reductions are exhausted.
+ *  6. If BM == 32 then transpose in SM and simd reduce the running totals.
+ *     Otherwise write in shared memory and BN threads accumulate the running
+ *     totals with a loop.
  *  7. Write them to the output
- *
- * The kernel becomes verbose because we support all kinds of OOB checks. For
- * instance if we choose that reduction_stride must be larger than BN then we
- * can get rid of half the kernel.
  */
 template <
     typename T,
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 92eccf880..2c8f18430 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -202,7 +202,7 @@ inline int threadgroup_size_from_row_size(int row_size) {
 
   // 2 simdgroups per row for medium rows
   if (row_size <= 1024) {
-    return 64;
+    return 128;
   }
 
   // up to 32 simdgroups after that
@@ -458,14 +458,25 @@ void strided_reduce_small(
 
   // Figure out the grid dims
   MTL::Size grid_dims, group_dims;
-  // Case 1: everything is small so launch one thread per col reduce
-  if (args.reduction_size * args.non_col_reductions < 64) {
+  // Case 1: Small row small column
+  if (args.reduction_size * args.non_col_reductions < 64 &&
+      args.reduction_stride < 32) {
     grid_dims = output_grid_for_col_reduce(out, args);
     int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
     group_dims = MTL::Size(threadgroup_size, 1, 1);
   }
 
-  // Case 2: Reduction in the simdgroup
+  // Case 2: Long row small column
+  else if (args.reduction_size * args.non_col_reductions < 32) {
+    auto out_grid_dims = output_grid_for_col_reduce(out, args);
+    int threads_x =
+        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
+    int threadgroup_x = std::min(threads_x, 128);
+    grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
+    group_dims = MTL::Size(threadgroup_x, 1, 1);
+  }
+
+  // Case 3: Long row medium column
   else {
     args.reduce_shape.push_back(args.reduction_size);
     args.reduce_strides.push_back(args.reduction_stride);
@@ -508,7 +519,7 @@ void strided_reduce_looped(
 
   // Figure out the grid dims
   auto out_grid_size = output_grid_for_col_reduce(out, args);
-  int BN = (args.reduction_stride <= 256) ? 32 : 128;
+  int BN = (args.reduction_stride <= 1024) ? 32 : 128;
   int BM = 1024 / BN;
   int threadgroup_size = 4 * 32;
   MTL::Size grid_dims(
@@ -544,7 +555,8 @@ void strided_reduce_general_dispatch(
 
   // Prepare the arguments for the kernel
   ColReduceArgs args(in, plan, axes);
-  if (args.reduction_stride < 32) {
+  if (args.reduction_stride < 32 ||
+      args.reduction_size * args.non_col_reductions < 32) {
     return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
   }
 
diff --git a/mlx/backend/metal/reduce.h b/mlx/backend/metal/reduce.h
index a997d7e24..4d2829a9b 100644
--- a/mlx/backend/metal/reduce.h
+++ b/mlx/backend/metal/reduce.h
@@ -16,7 +16,8 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s);
+    const Stream& s,
+    std::vector<array>& copies);
 
 void row_reduce_general_dispatch(
     const array& in,
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index 0ef080f1a..f47a357dd 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -43,10 +43,6 @@ class TestReduce(mlx_tests.MLXTestCase):
                         z_npy = np.sum(y_npy, axis=a) / 1000
                         z_mlx = mx.sum(y_mlx, axis=a) / 1000
                         mx.eval(z_mlx)
-                        if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
-                            import pdb
-
-                            pdb.set_trace()
                         self.assertTrue(
                             np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
                         )