From d52aa2464ee92394a5495ed5624fea259c076101 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 3 Oct 2025 16:19:33 -0700
Subject: [PATCH] Fix and refactor row-reduce

---
 mlx/backend/cuda/reduce/row_reduce.cu | 135 +++++++++++---------------
 1 file changed, 57 insertions(+), 78 deletions(-)

diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 35f2287d6..1ecbdf698 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -83,7 +83,8 @@ struct RowReduceArgs {
 };
 
 template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
-__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
+__global__ void
+row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -91,8 +92,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   const U init = cu::ReduceInit<ReduceOp, T>::value();
   ReduceOp op;
 
-  T vals[M][N];
-  U accs[M];
+  AlignedVector<T, N> vals[M];
+  AlignedVector<U, M> accs;
   for (int i = 0; i < M; i++) {
     accs[i] = init;
   }
@@ -101,43 +102,31 @@
       min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
   const size_t full_blocks = size / (block.size() * N);
   const size_t final_offset = full_blocks * (block.size() * N);
-  in += start_row * size;
+  in += start_row * size + block.thread_rank() * N;
   out += start_row;
 
-  if (size % N == 0) {
-    for (size_t r = 0; r < full_blocks; r++) {
-      for (int k = 0; k < M; k++) {
-        cub::LoadDirectBlockedVectorized(
-            block.thread_rank(),
-            in + k * size + r * (block.size() * N),
-            vals[k]);
-        for (int j = 0; j < N; j++) {
-          accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
-        }
-      }
-    }
-  } else {
-    for (size_t r = 0; r < full_blocks; r++) {
-      for (int k = 0; k < M; k++) {
-        cub::LoadDirectBlocked(
-            block.thread_rank(),
-            in + k * size + r * (block.size() * N),
-            vals[k]);
-        for (int j = 0; j < N; j++) {
-          accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
-        }
+  for (size_t r = 0; r < full_blocks; r++) {
+    for (int k = 0; k < M; k++) {
+      vals[k] = load_vector<N>(in + k * size, 0);
+    }
+    for (int k = 0; k < M; k++) {
+      for (int j = 0; j < N; j++) {
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
+
+    in += block.size() * N;
   }
 
   if (final_offset < size) {
     for (int k = 0; k < M; k++) {
-      cub::LoadDirectBlocked(
-          block.thread_rank(),
-          in + k * size + final_offset,
-          vals[k],
-          size,
-          cast_to<T>(init));
+      for (int i = 0; i < N; i++) {
+        vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
+            ? in[k * size + i]
+            : cast_to<T>(init);
+      }
+    }
+    for (int k = 0; k < M; k++) {
       for (int j = 0; j < N; j++) {
         accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
@@ -145,13 +134,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   }
 
   __shared__ U shared_accumulators[32 * M];
-  block_reduce(block, warp, accs, shared_accumulators, op, init);
+  block_reduce(block, warp, accs.val, shared_accumulators, op, init);
 
   if (block.thread_rank() == 0) {
     if (grid.block_rank() * M + M <= n_rows) {
-      for (int i = 0; i < M; i++) {
-        out[i] = accs[i];
-      }
+      store_vector(out, 0, accs);
     } else {
       short offset = grid.block_rank() * M + M - n_rows;
       for (int i = offset; i < M; i++) {
@@ -161,17 +148,10 @@
-template <
-    typename T,
-    typename U,
-    typename Op,
-    int NDIM,
-    int BLOCK_DIM,
-    int N_READS = 4>
+template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
 __global__ void row_reduce_looped(
-    T* in,
+    const T* in,
     U* out,
-    size_t out_size,
     const __grid_constant__ RowReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
@@ -185,30 +165,29 @@ __global__ void row_reduce_looped(
   U init = ReduceInit<Op, T>::value();
   total[0] = init;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
-  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
+  size_t full_blocks = args.row_size / (block.size() * N_READS);
+  size_t final_offset = full_blocks * (block.size() * N_READS);
 
   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  in += block.thread_rank() * N_READS;
 
   for (size_t n = 0; n < args.non_row_reductions; n++) {
+    const T* inlocal = in + loop.location();
     for (size_t r = 0; r < full_blocks; r++) {
-      T vals[N_READS];
-      cub::LoadDirectBlockedVectorized(
-          block.thread_rank(),
-          in + loop.location() + r * BLOCK_DIM * N_READS,
-          vals);
+      auto vals = load_vector<N_READS>(inlocal, 0);
       for (int i = 0; i < N_READS; i++) {
        total[0] = op(total[0], cast_to<U>(vals[i]));
       }
+      inlocal += block.size() * N_READS;
     }
 
     if (final_offset < args.row_size) {
       T vals[N_READS];
-      cub::LoadDirectBlocked(
-          block.thread_rank(),
-          in + loop.location() + final_offset,
-          vals,
-          args.row_size - final_offset,
-          cast_to<T>(init));
+      for (int i = 0; i < N_READS; i++) {
+        vals[i] =
+            ((final_offset + block.thread_rank() * N_READS + i) < args.row_size)
+            ? inlocal[i]
+            : cast_to<T>(init);
+      }
       for (int i = 0; i < N_READS; i++) {
         total[0] = op(total[0], cast_to<U>(vals[i]));
       }
@@ -234,8 +213,6 @@ void row_reduce_simple(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
-  constexpr int N_READS = 8;
-
   // Allocate data for the output using in's layout to avoid elem_to_loc in the
   // kernel.
   allocate_same_layout(out, in, axes);
@@ -250,14 +227,26 @@ void row_reduce_simple(
       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       using U = typename cu::ReduceResult<OP, T>::type;
 
-      // Cub doesn't like const pointers for vectorized loads. (sigh)
-      T* indata = const_cast<T*>(in.data<T>());
+      constexpr int N_READS = 16 / sizeof(T);
 
       // Calculate the grid and block dims
       size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
-      int threads = std::min(1024UL, reductions);
-      threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+      int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
+      if (warps > 128) {
+        warps = 32;
+      } else {
+        warps = 16;
+      }
+      int best = reductions;
+      for (int j = warps; j >= 1; j /= 2) {
+        int t = reductions % (j * WARP_SIZE);
+        if (t < best) {
+          warps = j;
+          best = t;
+        }
+      }
+      int threads = warps * WARP_SIZE;
       dim3 block(threads, 1, 1);
 
       // Pick the kernel
@@ -267,6 +256,7 @@
         kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
       }
 
+      T* indata = const_cast<T*>(in.data<T>());
       int size = plan.shape.back();
       encoder.add_kernel_node(
           kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -295,8 +285,6 @@ void row_reduce_looped(
       using OP = MLX_GET_TYPE(reduce_type_tag);
       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       using U = typename cu::ReduceResult<OP, T>::type;
-      // Cub doesn't like const pointers for vectorized loads. (sigh)
-      T* indata = const_cast<T*>(in.data<T>());
 
       // Calculate the grid and block dims
       args.sort_access_pattern(in, axes);
@@ -307,22 +295,13 @@ void row_reduce_looped(
       dim3 block(threads, 1, 1);
 
       // Pick the kernel
-      auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
+      auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
       dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
-        dispatch_block_dim(threads, [&](auto threads_constant) {
-          kernel = cu::row_reduce_looped<
-              T,
-              U,
-              OP,
-              reduce_ndim.value,
-              threads_constant.value,
-              N_READS>;
-          block.x = threads_constant.value;
-        });
+        kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
       });
 
       encoder.add_kernel_node(
-          kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
+          kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
     });
   });
 }
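
Note, not part of the patch: the new thread-count selection in the row_reduce_simple host hunk replaces the old "cap at 1024 and round up to a warp multiple" rule with a search over power-of-two warp counts that minimizes the leftover tail of a row. The following is a minimal host-side sketch that mirrors that hunk's logic so it can be checked in isolation; it assumes WARP_SIZE is 32, and the helper name pick_threads is made up for illustration.

#include <cstdio>

// Mirrors the warp-count heuristic from the row_reduce_simple host hunk.
// Assumes WARP_SIZE == 32; pick_threads is a hypothetical name.
constexpr int WARP_SIZE = 32;

int pick_threads(size_t reductions) {
  // reductions = number of N_READS-wide chunks in one row.
  int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
  // Start from 32 warps for long rows, 16 for short ones.
  warps = (warps > 128) ? 32 : 16;
  // Among warp counts warps, warps/2, ..., 1, keep the one whose block
  // width leaves the smallest tail (reductions modulo block width).
  int best = reductions;
  for (int j = warps; j >= 1; j /= 2) {
    int t = reductions % (j * WARP_SIZE);
    if (t < best) {
      warps = j;
      best = t;
    }
  }
  return warps * WARP_SIZE;
}

int main() {
  for (size_t reductions : {1000, 4096, 100000}) {
    std::printf("reductions=%zu -> threads=%d\n", reductions, pick_threads(reductions));
  }
  return 0;
}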