From cc4b995723acabe3d381f210af61193f21f9fa1f Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 21 Jun 2025 23:39:40 -0700
Subject: [PATCH] Working col reduce

---
 mlx/backend/cuda/reduce/col_reduce.cu | 255 +++++++++++---------------
 1 file changed, 110 insertions(+), 145 deletions(-)

diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index 9911a6fe0..bbfed594d 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -64,86 +64,6 @@ struct ColReduceArgs {
   }
 };
 
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void col_reduce_small(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-
-  int column =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  if (column * N_READS >= args.reduction_stride) {
-    return;
-  }
-
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  Op op;
-  U totals[N_READS];
-  for (int i = 0; i < N_READS; i++) {
-    totals[i] = ReduceInit<Op, T>::value();
-  }
-
-  // Read input to local.
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(
-      block.thread_index().y,
-      args.reduce_shape.data(),
-      args.reduce_strides.data());
-  for (size_t r = block.thread_index().y;
-       r < args.non_col_reductions * args.reduction_size;
-       r += block.dim_threads().y) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location()),
-        vals,
-        args.reduction_stride,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
-    }
-    loop.next(
-        block.dim_threads().y,
-        args.reduce_shape.data(),
-        args.reduce_strides.data());
-  }
-
-  // Do block reduce when each column has more than 1 element to reduce.
-  if (block.dim_threads().y > 1) {
-    __shared__ U shared_vals[32 * 8 * N_READS];
-    size_t col =
-        block.thread_index().y * block.dim_threads().x + block.thread_index().x;
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[col * N_READS + i] = totals[i];
-    }
-    block.sync();
-    if (block.thread_index().y == 0) {
-      for (int i = 0; i < N_READS; i++) {
-        totals[i] = shared_vals[block.thread_index().x * N_READS + i];
-      }
-      for (int j = 1; j < block.dim_threads().y; j++) {
-        col = j * block.dim_threads().x + block.thread_index().x;
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
-        }
-      }
-    }
-  }
-
-  // Write result.
-  if (block.thread_index().y == 0) {
-    cub::StoreDirectBlocked(
-        column,
-        out + out_idx * args.reduction_stride,
-        totals,
-        args.reduction_stride);
-  }
-}
-
 template <
     typename T,
     typename U,
@@ -152,67 +72,83 @@ template <
     int BM,
     int BN,
     int N_READS = 4>
-__global__ void col_reduce_looped(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
+__global__ void
+col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
-  constexpr int n_warps = BN / N_READS;
+  constexpr int threads_per_row = BN / N_READS;
 
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  // Compute the indices for the tile
+  size_t tile_idx = grid.block_rank();
+  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
+  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
 
+  // Compute the indices for the thread within the tile
+  short thread_x = block.thread_rank() % threads_per_row;
+  short thread_y = block.thread_rank() / threads_per_row;
+
+  // Move the input pointer
+  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
+      tile_x * BN;
+
+  // Initialize the running totals
   Op op;
   U totals[N_READS];
   for (int i = 0; i < N_READS; i++) {
    totals[i] = ReduceInit<Op, T>::value();
   }
 
-  // Read input to local.
-  int r = block.thread_rank() / n_warps;
-  int column = block.thread_rank() % n_warps;
-  int in_offset = grid.block_index().x * BN;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
-  for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location() + in_offset),
-        vals,
-        args.reduction_stride - in_offset,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
+  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
+  size_t total = args.non_col_reductions * args.reduction_size;
+  if (tile_x * BN + BN <= args.reduction_stride) {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+    }
+  } else {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlocked(
+          thread_x,
+          in + loop.location(),
+          vals,
+          args.reduction_stride - tile_x * BN,
+          __cast<T, U>(ReduceInit<Op, T>::value()));
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
    }
-    loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }
 
   // Do warp reduce for each output.
-  constexpr int n_outputs = BN / n_warps;
+  constexpr int n_outputs = BN / threads_per_row;
   static_assert(BM == 32 && n_outputs == N_READS);
   __shared__ U shared_vals[BM * BN];
-  size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
+  short s_idx = thread_y * BN + thread_x * N_READS;
   for (int i = 0; i < N_READS; i++) {
-    shared_vals[col + i] = totals[i];
+    shared_vals[s_idx + i] = totals[i];
   }
   block.sync();
-  col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
+  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
   for (int i = 0; i < n_outputs; i++) {
-    totals[i] = cg::reduce(warp, shared_vals[col + i], op);
+    totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
   }
 
   // Write result.
   if (warp.thread_rank() == 0) {
-    size_t out_offset = grid.block_index().x * BN;
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
-        out + out_idx * args.reduction_stride + out_offset,
+        out + tile_y * args.reduction_stride + tile_x * BN,
         totals,
-        args.reduction_stride - out_offset);
+        args.reduction_stride - tile_x * BN);
   }
 }
 
@@ -230,6 +166,53 @@ inline auto output_grid_for_col_reduce(
   return get_2d_grid_dims(out_shape, out_strides);
 }
 
+void col_reduce_looped(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::ColReduceArgs args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+
+          constexpr int N_READS = 4;
+          constexpr int BM = 32;
+          constexpr int BN = 32;
+          dim3 grid = output_grid_for_col_reduce(out, args);
+          size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
+          if (grid.x * extra_blocks < INT32_MAX) {
+            grid.x *= extra_blocks;
+          } else if (grid.y * extra_blocks < 65536) {
+            grid.y *= extra_blocks;
+          } else {
+            throw std::runtime_error(
+                "[col_reduce_looped] Need to factorize reduction_stride");
+          }
+          int blocks = BM * BN / N_READS;
+          auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
+          kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
+        });
+      });
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -237,42 +220,24 @@ void col_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current col reduce options
+  //
+  // - col_reduce_looped
+  //
+  // It is a general strided reduce. Each threadblock computes the output for
+  // a subrow of the fast moving axis. For instance 32 elements.
+  //
+  // Notes: As in row reduce we opt to read as much in order as possible and
+  //        leave transpositions as they are (contrary to our Metal backend).
+  //
+  // Moreover, we still need different kernels for short rows and for tuning.
+
+  // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);
 
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr int N_READS = 4;
-          dim3 block_dims;
-          dim3 num_blocks = output_grid_for_col_reduce(out, args);
-          num_blocks.z = num_blocks.y;
-          num_blocks.y = num_blocks.x;
-          auto kernel =
-              cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          size_t total = args.non_col_reductions * args.reduction_size;
-          if (total < 32) {
-            size_t stride_blocks =
-                cuda::ceil_div(args.reduction_stride, N_READS);
-            block_dims.x = std::min(stride_blocks, 32ul);
-            block_dims.y = std::min(total, 8ul);
-            num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
-          } else {
-            constexpr int BM = 32;
-            constexpr int BN = 32;
-            block_dims.x = BM * BN / N_READS;
-            num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
-            kernel = cu::
-                col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), args);
-        });
-      });
-    });
-  });
+  // Fallback col reduce
+  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
 
 } // namespace mlx::core
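
A note on the launch math in the new host-side col_reduce_looped: output_grid_for_col_reduce() yields one block per output row, and the launcher then multiplies whichever grid dimension has room by ceil(reduction_stride / BN) so that every block owns one BN-wide column tile. The kernel inverts that flattening with tile_idx % tiles_per_row and tile_idx / tiles_per_row. The following standalone C++ sketch (illustrative names only, not part of the patch or the MLX sources) walks that mapping for a small example and shows which tiles take the vectorized-load path versus the guarded tail path:

  #include <algorithm>
  #include <cstddef>
  #include <cstdio>

  // Sketch: map a flat block rank back to the (tile_x, tile_y) pair used by
  // col_reduce_looped, assuming the launcher multiplied one grid dimension by
  // ceil(reduction_stride / BN). All names and numbers here are illustrative.
  int main() {
    const size_t reduction_stride = 70; // length of the fast-moving axis
    const int BN = 32;                  // columns handled per tile
    const size_t n_rows = 3;            // output rows (product of the other axes)

    size_t tiles_per_row = (reduction_stride + BN - 1) / BN; // = 3 here
    size_t n_blocks = n_rows * tiles_per_row;

    for (size_t block_rank = 0; block_rank < n_blocks; block_rank++) {
      size_t tile_x = block_rank % tiles_per_row; // which BN-wide column slice
      size_t tile_y = block_rank / tiles_per_row; // which output row
      bool full = (tile_x * BN + BN) <= reduction_stride; // vectorized path?
      printf("block %zu -> tile_y=%zu tile_x=%zu columns [%zu, %zu) %s\n",
             block_rank, tile_y, tile_x, tile_x * BN,
             std::min(tile_x * BN + BN, reduction_stride),
             full ? "(vectorized)" : "(guarded tail)");
    }
    return 0;
  }

With reduction_stride = 70 and BN = 32, tiles 0 and 1 of each row are full 32-column tiles and can use cub::LoadDirectBlockedVectorized, while tile 2 covers only columns [64, 70) and falls back to the bounded cub::LoadDirectBlocked branch.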
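
Similarly, the in-kernel reduction after the loads relies on a fixed shared-memory layout: each of the BM * BN / N_READS threads writes its N_READS running totals into row thread_y of a BM x BN tile, and after the barrier each of the BN / N_READS warps reduces n_outputs = N_READS adjacent columns with cg::reduce, lane r reading row r. A small host-side check (plain C++, only mirroring the two index formulas from the kernel; it is a sketch, not the kernel itself) confirms that the write and read patterns each cover the BM x BN buffer exactly once:

  #include <cassert>
  #include <cstdio>
  #include <vector>

  // Sketch: reproduce the shared-memory indexing of col_reduce_looped on the
  // host. BM/BN/N_READS mirror the launch constants; everything else is
  // illustrative.
  int main() {
    const int BM = 32, BN = 32, N_READS = 4;
    const int threads_per_row = BN / N_READS;    // 8
    const int n_threads = BM * BN / N_READS;     // 256 threads per block
    const int n_outputs = BN / threads_per_row;  // 4, equal to N_READS

    std::vector<int> writes(BM * BN, 0), reads(BM * BN, 0);

    // Write phase: thread t stores its N_READS partials in row thread_y.
    for (int t = 0; t < n_threads; t++) {
      int thread_x = t % threads_per_row;
      int thread_y = t / threads_per_row;
      int s_idx = thread_y * BN + thread_x * N_READS;
      for (int i = 0; i < N_READS; i++)
        writes[s_idx + i]++;
    }

    // Read phase: warp w reduces columns [w * n_outputs, (w + 1) * n_outputs),
    // with lane r reading row r of the tile.
    const int n_warps = n_threads / 32; // 8
    for (int w = 0; w < n_warps; w++) {
      for (int lane = 0; lane < 32; lane++) {
        int s_idx = lane * BN + w * n_outputs;
        for (int i = 0; i < n_outputs; i++)
          reads[s_idx + i]++;
      }
    }

    for (int i = 0; i < BM * BN; i++)
      assert(writes[i] == 1 && reads[i] == 1);
    printf("each of the %d partials is written and reduced exactly once\n",
           BM * BN);
    return 0;
  }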