From 17432e7885aeea66cdb3c0ac79bc191c2be939e0 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos <a_katharopoulos@apple.com>
Date: Wed, 1 Oct 2025 21:11:35 -0700
Subject: [PATCH] Add a small column specialization to reduce

---
 mlx/backend/cuda/reduce/col_reduce.cu | 88 ++++++++++++++++++++++++++-
 1 file changed, 87 insertions(+), 1 deletion(-)

diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index 04c400c47..eb325a987 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   }
 }
 
+template <typename T, typename U, typename Op, int N_READS>
+__global__ void col_reduce_small(
+    const T* in,
+    U* out,
+    const __grid_constant__ ColReduceArgs args,
+    size_t total) {
+  Op op;
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  const auto idx = grid.thread_rank() * N_READS;
+  const auto before_axis = idx / args.reduction_stride;
+  const auto after_axis = idx % args.reduction_stride;
+  const auto offset =
+      before_axis * args.reduction_stride * args.reduction_size + after_axis;
+
+  if (idx >= total) {
+    return;
+  }
+
+  in += offset;
+  out += idx;
+
+  AlignedVector<U, N_READS> accumulator;
+  for (int i = 0; i < N_READS; i++) {
+    accumulator[i] = ReduceInit<Op, T>::value();
+  }
+
+  for (int i = 0; i < args.reduction_size; i++) {
+    auto values = load_vector<N_READS>(in, 0);
+
+    for (int j = 0; j < N_READS; j++) {
+      accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
+    }
+
+    in += args.reduction_stride;
+  }
+
+  store_vector(out, 0, accumulator);
+}
+
 } // namespace cu
 
 inline auto output_grid_for_col_reduce(
@@ -236,6 +277,43 @@ void col_reduce_looped(
   });
 }
 
+void col_reduce_small(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::ColReduceArgs args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      using OP = MLX_GET_TYPE(reduce_type_tag);
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      using U = typename cu::ReduceResult<OP, T>::type;
+
+      constexpr int N_READS = 16 / sizeof(T);
+      auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
+      auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
+      auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
+      encoder.add_kernel_node(
+          kernel,
+          grid,
+          block,
+          0,
+          in.data<T>(),
+          out.data<U>(),
+          args,
+          out.size());
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -258,8 +336,16 @@ void col_reduce(
   // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);
 
+  // Small col reduce with a single or contiguous reduction axis
+  if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
+      args.reduction_stride % 4 == 0) {
+    col_reduce_small(
+        encoder, in, out, reduce_type, axes, plan, std::move(args));
+    return;
+  }
+
   // Fallback col reduce
-  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
+  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
 }
 
 } // namespace mlx::core