From ed61cb2802bf8647facb52a8055064567f431fb0 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 3 Oct 2025 16:54:21 -0700
Subject: [PATCH] Use the same tuning for looped

---
 mlx/backend/cuda/reduce/row_reduce.cu | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 45bf53575..960872982 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -270,8 +270,6 @@ void row_reduce_looped(
     const std::vector<int>& axes,
     const ReductionPlan& plan,
     cu::RowReduceArgs args) {
-  constexpr int N_READS = 8;
-
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
@@ -284,12 +282,27 @@ void row_reduce_looped(
     using T = cuda_type_t<CTYPE>;
     using U = typename cu::ReduceResult<OP, T>::type;

+    constexpr int N_READS = 16 / sizeof(T);
+
     // Calculate the grid and block dims
     args.sort_access_pattern(in, axes);
     dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
     size_t reductions = (args.row_size + N_READS - 1) / N_READS;
-    int threads = std::min(1024UL, reductions);
-    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+    int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
+    if (warps > 128) {
+      warps = 32;
+    } else {
+      warps = 16;
+    }
+    int best = reductions;
+    for (int j = warps; j >= 1; j /= 2) {
+      int t = reductions % (j * WARP_SIZE);
+      if (t < best) {
+        warps = j;
+        best = t;
+      }
+    }
+    int threads = warps * WARP_SIZE;
     dim3 block(threads, 1, 1);

     // Pick the kernel
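
The tuning this patch introduces can be read in isolation: N_READS is now sized so
each thread moves 16 bytes per load (presumably one 128-bit access, hence fewer
elements per read for wider types), and the block size is picked by starting from a
warp-count cap (32 warps for rows needing more than 128 warps, 16 otherwise) and
halving the warp count whenever that shrinks the tail remainder
reductions % (warps * WARP_SIZE). Below is a minimal host-side sketch of that
heuristic, assuming WARP_SIZE is 32; pick_block_threads and n_reads are hypothetical
helpers for illustration (MLX inlines this logic in row_reduce_looped), and the
16-byte-load reading is an inference from the patch, not a claim from its author.

    // Standalone sketch of the block-size tuning added by this patch.
    // Assumptions: WARP_SIZE == 32; pick_block_threads / n_reads are
    // illustrative names, not MLX API.
    #include <cstddef>
    #include <cstdio>

    constexpr int WARP_SIZE = 32;

    // Mirrors "N_READS = 16 / sizeof(T)": each thread reads 16 bytes per step.
    template <typename T>
    constexpr int n_reads() {
      return 16 / sizeof(T);
    }

    int pick_block_threads(size_t reductions) {
      // Warp count needed to cover all per-thread reductions, then capped:
      // up to 32 warps (1024 threads) for long rows, 16 warps otherwise.
      int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
      warps = (warps > 128) ? 32 : 16;

      // Halve the warp count while that lowers the leftover work in the
      // final loop iteration, i.e. minimize reductions % (warps * WARP_SIZE).
      size_t best = reductions;
      for (int j = warps; j >= 1; j /= 2) {
        size_t t = reductions % (j * WARP_SIZE);
        if (t < best) {
          warps = j;
          best = t;
        }
      }
      return warps * WARP_SIZE;
    }

    int main() {
      // Example: a row of 4096 floats -> N_READS = 4, 1024 per-thread reductions.
      size_t row_size = 4096;
      size_t reductions = (row_size + n_reads<float>() - 1) / n_reads<float>();
      std::printf("threads per block: %d\n", pick_block_threads(reductions));
      return 0;
    }

For the example above the search settles on 16 warps (512 threads), since
1024 % 512 == 0 leaves no tail iteration; the old code would have launched a full
1024-thread block regardless of the remainder.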