From cad47a32e2aad8b9b2bfcaf5f3c5cb392dc66dcf Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 6 Oct 2025 23:28:29 -0700 Subject: [PATCH] Re-tune --- mlx/backend/cuda/reduce/row_reduce.cu | 93 ++++++++++++++------------- 1 file changed, 50 insertions(+), 43 deletions(-) diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 960872982..2e841cfcd 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -163,35 +163,60 @@ __global__ void row_reduce_looped( U init = ReduceInit::value(); total[0] = init; LoopedElemToLoc 2)> loop(args.reduce_ndim); - size_t full_blocks = args.row_size / (block.size() * N_READS); - size_t final_offset = full_blocks * (block.size() * N_READS); + const size_t full_blocks = args.row_size / (block.size() * N_READS); + const size_t final_offset = full_blocks * (block.size() * N_READS); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += block.thread_rank() * N_READS; - for (size_t n = 0; n < args.non_row_reductions; n++) { - const T* inlocal = in + loop.location(); - for (size_t r = 0; r < full_blocks; r++) { - auto vals = load_vector(inlocal, 0); - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(vals[i])); - } - inlocal += block.size() * N_READS; + // Unaligned reduce + if (final_offset < args.row_size) { + bool mask[N_READS]; + for (int i = 0; i < N_READS; i++) { + mask[i] = + (final_offset + block.thread_rank() * N_READS + i) < args.row_size; } - if (final_offset < args.row_size) { - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = - ((final_offset + block.thread_rank() * N_READS + i) < args.row_size) - ? inlocal[i] - : cast_to(init); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + const T* inlocal = in + loop.location(); + + for (size_t r = 0; r < full_blocks; r++) { + auto vals = load_vector(inlocal, 0); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + inlocal += block.size() * N_READS; } - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(vals[i])); + + { + T vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = mask[i] ? inlocal[i] : cast_to(init); + } + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } } + + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + // Aligned case + else { + for (size_t n = 0; n < args.non_row_reductions; n++) { + const T* inlocal = in + loop.location(); + + for (size_t r = 0; r < full_blocks; r++) { + auto vals = load_vector(inlocal, 0); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + inlocal += block.size() * N_READS; + } + + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); } - // TODO: Maybe block.sync() here? - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); } __shared__ U shared_accumulators[32]; @@ -231,18 +256,9 @@ void row_reduce_simple( size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; - if (warps > 128) { + warps /= 4; + if (warps > 32) { warps = 32; - } else { - warps = 16; - } - int best = reductions; - for (int j = warps; j >= 1; j /= 2) { - int t = reductions % (j * WARP_SIZE); - if (t < best) { - warps = j; - best = t; - } } int threads = warps * WARP_SIZE; dim3 block(threads, 1, 1); @@ -289,18 +305,9 @@ void row_reduce_looped( dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); size_t reductions = (args.row_size + N_READS - 1) / N_READS; int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; - if (warps > 128) { + warps /= 4; + if (warps > 32) { warps = 32; - } else { - warps = 16; - } - int best = reductions; - for (int j = warps; j >= 1; j /= 2) { - int t = reductions % (j * WARP_SIZE); - if (t < best) { - warps = j; - best = t; - } } int threads = warps * WARP_SIZE; dim3 block(threads, 1, 1);