From abdb21f27c75355118ec1567b549a2b4d7df1f60 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 21 Jun 2025 12:37:35 -0700
Subject: [PATCH] Add helpers and atomic kernel

---
 mlx/backend/common/reduce.cpp            |  15 +-
 mlx/backend/common/reduce.h              |   4 +
 mlx/backend/cuda/reduce/all_reduce.cu    |  31 +-
 mlx/backend/cuda/reduce/reduce_utils.cuh |  73 +++-
 mlx/backend/cuda/reduce/row_reduce.cu    | 467 +++++++++++++++--------
 5 files changed, 394 insertions(+), 196 deletions(-)

diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp
index 5c7f63b75..ceef46400 100644
--- a/mlx/backend/common/reduce.cpp
+++ b/mlx/backend/common/reduce.cpp
@@ -5,11 +5,9 @@
 namespace mlx::core {
 
 std::pair<Shape, Strides> shapes_without_reduction_axes(
-    const array& x,
+    Shape shape,
+    Strides strides,
     const std::vector<int>& axes) {
-  auto shape = x.shape();
-  auto strides = x.strides();
-
   for (int i = axes.size() - 1; i >= 0; i--) {
     int a = axes[i];
     shape.erase(shape.begin() + a);
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }
 
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes) {
+  auto shape = x.shape();
+  auto strides = x.strides();
+  return shapes_without_reduction_axes(
+      std::move(shape), std::move(strides), axes);
+}
+
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h
index ddb5c3492..8b24f4f53 100644
--- a/mlx/backend/common/reduce.h
+++ b/mlx/backend/common/reduce.h
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
     const array& x,
     const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index 557c973f3..54dd351bb 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -15,6 +15,9 @@ namespace cg = cooperative_groups;
 
 template <typename T, typename U, typename ReduceOp, int N = 4>
 __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
+  // TODO: Process multiple "rows" in each thread
+  constexpr int M = 1;
+
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -23,10 +26,8 @@
   ReduceOp op;
 
   T vals[N];
-  U accs[N];
-  for (int i = 0; i < N; i++) {
-    accs[i] = init;
-  }
+  U accs[M];
+  accs[0] = init;
 
   size_t start = grid.block_rank() * block_step;
   size_t end = start + block_step;
@@ -35,7 +36,7 @@
   for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals);
     for (int j = 0; j < N; j++) {
-      accs[j] = op(accs[j], __cast<U, T>(vals[j]));
+      accs[0] = op(accs[0], __cast<U, T>(vals[j]));
     }
   }
 
@@ -45,26 +46,12 @@
     cub::LoadDirectBlocked(
        block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
     for (int i = 0; i < N; i++) {
-      accs[i] = op(accs[i], __cast<U, T>(vals[i]));
+      accs[0] = op(accs[0], __cast<U, T>(vals[i]));
     }
   }
 
-  for (int i = 1; i < N; i++) {
-    accs[0] = op(accs[0], accs[i]);
-  }
-  accs[0] = cg::reduce(warp, accs[0], op);
-
-  if (warp.meta_group_size() > 1) {
-    __shared__ U shared_accumulators[32];
-    if (warp.thread_rank() == 0) {
-      shared_accumulators[warp.meta_group_rank()] = accs[0];
-    }
-    block.sync();
-    accs[0] = (warp.thread_rank() < warp.meta_group_size())
-        ? shared_accumulators[warp.thread_rank()]
-        : init;
-    accs[0] = cg::reduce(warp, accs[0], op);
-  }
+  __shared__ U shared_accumulators[32];
+  block_reduce(block, warp, accs, shared_accumulators, op, init);
 
   if (block.thread_rank() == 0) {
     out[grid.block_rank()] = accs[0];
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index 2d83bd366..5a3d09bf5 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -4,7 +4,14 @@
 
 #include "mlx/backend/cuda/device/utils.cuh"
 
-namespace mlx::core::cu {
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
 
 template <size_t N>
 struct uint_by_size;
@@ -62,4 +69,66 @@ inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
   return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
 }
 
-} // namespace mlx::core::cu
+template <typename T, int N, typename Block, typename Warp, typename Op>
+inline __device__ void
+block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
+  // First reduce in the current warp
+  for (int i = 0; i < N; i++) {
+    vals[i] = cg::reduce(warp, vals[i], op);
+  }
+
+  // Reduce across warps
+  if (warp.meta_group_size() > 1) {
+    if (warp.thread_rank() == 0) {
+      for (int i = 0; i < N; i++) {
+        smem[warp.meta_group_rank() * N + i] = vals[i];
+      }
+    }
+    block.sync();
+    if (warp.thread_rank() < warp.meta_group_size()) {
+      for (int i = 0; i < N; i++) {
+        vals[i] = smem[warp.thread_rank() * N + i];
+      }
+    } else {
+      for (int i = 0; i < N; i++) {
+        vals[i] = init;
+      }
+    }
+    for (int i = 0; i < N; i++) {
+      vals[i] = cg::reduce(warp, vals[i], op);
+    }
+  }
+}
+
+} // namespace cu
+
+inline void allocate_same_layout(
+    array& out,
+    const array& in,
+    const std::vector<int>& axes) {
+  // Initialize out such that it matches in's layout. Basically we keep any
+  // transpositions as they were, which allows us either to skip finding the
+  // location of the output that matches the input or to simply do contiguous
+  // reads and writes.
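+  //
+  // For example (illustrative, assuming the reduced axes are kept as size-1
+  // dimensions in out): reducing axis 1 of a row-contiguous (4, 5, 6) input
+  // with strides (30, 6, 1) divides every output stride larger than
+  // in.strides(1) == 6 by in.shape(1) == 5, giving output strides (6, 6, 1),
+  // i.e. the same traversal order as the input.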
+  auto out_strides = in.strides();
+  for (auto ax : axes) {
+    for (auto& s : out_strides) {
+      if (s > in.strides(ax)) {
+        s /= in.shape(ax);
+      }
+    }
+  }
+  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
+  auto fl = in.flags();
+  fl.row_contiguous = rc;
+  fl.col_contiguous = cc;
+  fl.contiguous = data_size == out.size();
+  out.set_data(
+      allocator::malloc(out.nbytes()),
+      data_size,
+      out_strides,
+      fl,
+      allocator::free);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 0af6f27cc..735de5311 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -55,86 +55,109 @@ struct RowReduceArgs {
       non_row_reductions *= reduce_shape[i];
     }
   }
+
+  // Convert shape and strides as if in was contiguous
+  void convert_shapes_to_contiguous(
+      const array& in,
+      const std::vector<int>& axes) {
+    auto shape_vec = in.shape();
+    auto strides_vec = in.strides();
+    size_t s = 1;
+    for (int i = in.ndim() - 1; i >= 0; i--) {
+      strides_vec[i] = s;
+      s *= shape_vec[i];
+    }
+    std::tie(shape_vec, strides_vec) =
+        shapes_without_reduction_axes(shape_vec, strides_vec, axes);
+    std::tie(shape_vec, strides_vec) =
+        collapse_contiguous_dims(shape_vec, strides_vec);
+    shape = const_param(shape_vec);
+    strides = const_param(strides_vec);
+    ndim = shape_vec.size();
+  }
 };
 
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void row_reduce_small(
-    const T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
-  size_t out_idx = cg::this_grid().thread_rank();
-  if (out_idx >= out_size) {
-    return;
-  }
-
-  Op op;
-
-  U total_val = ReduceInit<Op, T>::value();
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  for (size_t n = 0; n < args.non_row_reductions; n++) {
-    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-      U vals[N_READS];
-      cub::LoadDirectBlocked(
-          r,
-          make_cast_iterator<U>(in + loop.location()),
-          vals,
-          args.row_size,
-          ReduceInit<Op, T>::value());
-      total_val = op(total_val, cub::ThreadReduce(vals, op));
-    }
-    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
-  }
-
-  out[out_idx] = total_val;
-}
-
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void row_reduce_small_warp(
-    const T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-  auto warp = cg::tiled_partition<WARP_SIZE>(block);
-
-  size_t out_idx = grid.thread_rank() / WARP_SIZE;
-  if (out_idx >= out_size) {
-    return;
-  }
-
-  Op op;
-
-  U total_val = ReduceInit<Op, T>::value();
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
-       n += WARP_SIZE) {
-    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-      U vals[N_READS];
-      cub::LoadDirectBlocked(
-          r,
-          make_cast_iterator<U>(in + loop.location()),
-          vals,
-          args.row_size,
-          ReduceInit<Op, T>::value());
-      total_val = op(total_val, cub::ThreadReduce(vals, op));
-    }
-    loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
-  }
-
-  total_val = cg::reduce(warp, total_val, op);
-
-  if (warp.thread_rank() == 0) {
-    out[out_idx] = total_val;
-  }
-}
+// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
+//__global__ void row_reduce_small(
+//     const T* in,
+//     U* out,
+//     size_t out_size,
+//     const __grid_constant__ RowReduceArgs args) {
+//   size_t out_idx = cg::this_grid().thread_rank();
+//   if (out_idx >= out_size) {
+//     return;
+//   }
+//
+//   Op op;
+//
+//   U total_val = ReduceInit<Op, T>::value();
+//   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+//
+//   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
+//   args.ndim);
+//
+//   for (size_t n = 0; n < args.non_row_reductions; n++) {
+//     for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
+//       U vals[N_READS];
+//       cub::LoadDirectBlocked(
+//           r,
+//           make_cast_iterator<U>(in + loop.location()),
+//           vals,
+//           args.row_size,
+//           ReduceInit<Op, T>::value());
+//       total_val = op(total_val, cub::ThreadReduce(vals, op));
+//     }
+//     loop.next(args.reduce_shape.data(), args.reduce_strides.data());
+//   }
+//
+//   out[out_idx] = total_val;
+// }
+//
+// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
+//__global__ void row_reduce_small_warp(
+//     const T* in,
+//     U* out,
+//     size_t out_size,
+//     const __grid_constant__ RowReduceArgs args) {
+//   auto grid = cg::this_grid();
+//   auto block = cg::this_thread_block();
+//   auto warp = cg::tiled_partition<WARP_SIZE>(block);
+//
+//   size_t out_idx = grid.thread_rank() / WARP_SIZE;
+//   if (out_idx >= out_size) {
+//     return;
+//   }
+//
+//   Op op;
+//
+//   U total_val = ReduceInit<Op, T>::value();
+//   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+//
+//   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
+//   args.ndim);
+//
+//   for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
+//        n += WARP_SIZE) {
+//     for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
+//       U vals[N_READS];
+//       cub::LoadDirectBlocked(
+//           r,
+//           make_cast_iterator<U>(in + loop.location()),
+//           vals,
+//           args.row_size,
+//           ReduceInit<Op, T>::value());
+//       total_val = op(total_val, cub::ThreadReduce(vals, op));
+//     }
+//     loop.next(WARP_SIZE, args.reduce_shape.data(),
+//     args.reduce_strides.data());
+//   }
+//
+//   total_val = cg::reduce(warp, total_val, op);
+//
+//   if (warp.thread_rank() == 0) {
+//     out[out_idx] = total_val;
+//   }
+// }
 
 template <typename T, typename U, typename Op, int N = 4, int M = 1>
 __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
@@ -153,59 +176,37 @@
   const size_t start_row =
       min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
+  const size_t full_blocks = size / (block.size() * N);
+  const size_t final_offset = full_blocks * (block.size() * N);
   in += start_row * size;
   out += start_row;
 
-  int i = 0;
-  for (; i + block.size() * N <= size; i += block.size() * N) {
+  for (size_t r = 0; r < full_blocks; r++) {
     for (int k = 0; k < M; k++) {
       cub::LoadDirectBlockedVectorized(
-          block.thread_rank(), in + k * size + i, vals[k]);
+          block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
       for (int j = 0; j < N; j++) {
         accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
       }
     }
   }
 
-  if (size > i) {
+  if (final_offset < size) {
     for (int k = 0; k < M; k++) {
       cub::LoadDirectBlocked(
           block.thread_rank(),
-          in + k * size + i,
+          in + k * size + final_offset,
          vals[k],
          size,
          __cast<T, U>(init));
-      for (int j = 0; i < N; i++) {
+      for (int j = 0; j < N; j++) {
         accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
       }
     }
   }
 
-  for (int i = 0; i < M; i++) {
-    accs[i] = cg::reduce(warp, accs[i], op);
-  }
-
-  if (warp.meta_group_size() > 1) {
-    __shared__ U shared_accumulators[32 * M];
-    if (warp.thread_rank() == 0) {
-      for (int i = 0; i < M; i++) {
-        shared_accumulators[warp.meta_group_rank() * M + i] = accs[i];
-      }
-    }
-    block.sync();
-    if (warp.thread_rank() < warp.meta_group_size()) {
-      for (int i = 0; i < M; i++) {
-        accs[i] = shared_accumulators[warp.thread_rank() * M + i];
-      }
-    } else {
-      for (int i = 0; i < M; i++) {
-        accs[i] = init;
-      }
-    }
-    for (int i = 0; i < M; i++) {
-      accs[i] = cg::reduce(warp, accs[i], op);
-    }
-  }
+  __shared__ U shared_accumulators[32 * M];
+  block_reduce(block, warp, accs, shared_accumulators, op, init);
 
   if (block.thread_rank() == 0) {
     if (grid.block_rank() * M + M <= n_rows) {
@@ -226,7 +227,7 @@ template <
     typename U,
     typename Op,
     int NDIM,
-    int BLOCK_DIM_X,
+    int BLOCK_DIM,
     int N_READS = 4>
 __global__ void row_reduce_looped(
     T* in,
@@ -237,27 +238,28 @@
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
 
-  size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
-  if (out_idx >= out_size) {
-    return;
-  }
+  size_t out_idx = grid.block_rank();
 
   Op op;
 
-  U total_val = ReduceInit<Op, T>::value();
+  U total[1];
+  U init = ReduceInit<Op, T>::value();
+  total[0] = init;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
 
-  size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
-  size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
+  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
+  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
+
+  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
 
   for (size_t n = 0; n < args.non_row_reductions; n++) {
     for (size_t r = 0; r < full_blocks; r++) {
       T vals[N_READS];
       cub::LoadDirectBlockedVectorized(
           block.thread_rank(),
-          in + loop.location() + r * BLOCK_DIM_X * N_READS,
+          in + loop.location() + r * BLOCK_DIM * N_READS,
           vals);
       for (int i = 0; i < N_READS; i++) {
-        total_val = op(total_val, __cast<U, T>(vals[i]));
+        total[0] = op(total[0], __cast<U, T>(vals[i]));
       }
     }
     if (final_offset < args.row_size) {
@@ -267,26 +269,117 @@
       T vals[N_READS];
       cub::LoadDirectBlocked(
           block.thread_rank(),
           in + loop.location() + final_offset,
           vals,
           args.row_size - final_offset,
-          __cast<T, U>(ReduceInit<Op, T>::value()));
+          __cast<T, U>(init));
       for (int i = 0; i < N_READS; i++) {
-        total_val = op(total_val, __cast<U, T>(vals[i]));
+        total[0] = op(total[0], __cast<U, T>(vals[i]));
       }
     }
+    // TODO: Maybe block.sync() here?
     loop.next(args.reduce_shape.data(), args.reduce_strides.data());
   }
 
-  typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
-  __shared__ typename BlockReduceT::TempStorage temp;
-
-  total_val = BlockReduceT(temp).Reduce(total_val, op);
+  __shared__ U shared_accumulators[32];
+  block_reduce(block, warp, total, shared_accumulators, op, init);
 
   if (block.thread_rank() == 0) {
-    out[out_idx] = total_val;
+    out[out_idx] = total[0];
+  }
+}
+
+template <typename T, typename U, typename Op, int N = 4>
+__global__ void reduce_initialize(U* out, size_t out_size) {
+  auto grid = cg::this_grid();
+  if (grid.thread_rank() * N + N <= out_size) {
+    for (int i = 0; i < N; i++) {
+      out[grid.thread_rank() * N + i] = ReduceInit<Op, T>::value();
+    }
+  } else {
+    for (int i = grid.thread_rank() * N; i < out_size; i++) {
+      out[i] = ReduceInit<Op, T>::value();
+    }
+  }
+}
+
+template <typename T, typename U, typename Op, int BLOCK_DIM, int N_READS = 4>
+__global__ void row_reduce_atomics(
+    T* in,
+    U* out,
+    size_t out_size,
+    const __grid_constant__ RowReduceArgs args) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  size_t reduction_idx = grid.block_rank() / out_size;
+  size_t out_idx = grid.block_rank() % out_size;
+
+  Op op;
+
+  U total[1];
+  U init = ReduceInit<Op, T>::value();
+  total[0] = init;
+
+  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
+  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
+
+  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  in += elem_to_loc(
+      reduction_idx,
+      args.reduce_shape.data(),
+      args.reduce_strides.data(),
+      args.reduce_ndim);
+
+  for (size_t r = 0; r < full_blocks; r++) {
+    T vals[N_READS];
+    cub::LoadDirectBlockedVectorized(
+        block.thread_rank(), in + r * BLOCK_DIM * N_READS, vals);
+    for (int i = 0; i < N_READS; i++) {
+      total[0] = op(total[0], __cast<U, T>(vals[i]));
+    }
+  }
+  if (final_offset < args.row_size) {
+    T vals[N_READS];
+    cub::LoadDirectBlocked(
+        block.thread_rank(),
+        in + final_offset,
+        vals,
+        args.row_size - final_offset,
+        __cast<T, U>(init));
+    for (int i = 0; i < N_READS; i++) {
+      total[0] = op(total[0], __cast<U, T>(vals[i]));
+    }
+  }
+
+  __shared__ U shared_accumulators[32];
+  block_reduce(block, warp, total, shared_accumulators, op, init);
+
+  if (block.thread_rank() == 0) {
+    op.atomic_update(out + out_idx, total[0]);
   }
 }
 
 } // namespace cu
 
+void reduce_initialize(
+    cu::CommandEncoder& encoder,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  constexpr int N_WRITES = 8;
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+
+        auto kernel = cu::reduce_initialize<T, U, OP, N_WRITES>;
+        auto [grid, block] =
+            get_launch_args(kernel, out, out.size() >= 1UL << 31, N_WRITES);
+        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
+      });
+    });
+  });
+}
+
 void row_reduce_simple(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -296,23 +389,9 @@
     const ReductionPlan& plan) {
   constexpr int N_READS = 8;
 
-  // Initialize out such that its strides match in's layout (except the fastest
-  // moving axis)
-  auto out_strides = in.strides();
-  for (auto& s : out_strides) {
-    s /= plan.shape.back();
-  }
-  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
-  auto fl = in.flags();
-  fl.row_contiguous = rc;
-  fl.col_contiguous = cc;
-  fl.contiguous = data_size == out.size();
-  out.set_data(
-      allocator::malloc(out.nbytes()),
-      data_size,
-      out_strides,
-      fl,
-      allocator::free);
+  // Allocate data for the output using in's layout to avoid elem_to_loc in the
+  // kernel.
+  allocate_same_layout(out, in, axes);
 
   // Just a way to get out of the constness because cub doesn't like it ...
   // (sigh)
@@ -356,31 +435,13 @@ void row_reduce_looped(
     array& out,
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
-    const ReductionPlan& plan) {
+    const ReductionPlan& plan,
+    cu::RowReduceArgs args) {
   constexpr int N_READS = 8;
 
-  // Initialize out such that it matches in's layout. Basically we keep any
-  // transpositions as it were and that allows us to skip finding the location
-  // of the output that matches the input.
-  auto out_strides = in.strides();
-  for (auto ax : axes) {
-    for (auto& s : out_strides) {
-      if (s > in.strides(ax)) {
-        s /= in.shape(ax);
-      }
-    }
-  }
-  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
-  auto fl = in.flags();
-  fl.row_contiguous = rc;
-  fl.col_contiguous = cc;
-  fl.contiguous = data_size == out.size();
-  out.set_data(
-      allocator::malloc(out.nbytes()),
-      data_size,
-      out_strides,
-      fl,
-      allocator::free);
+  // Allocate data for the output using in's layout to access it as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
 
   // Just a way to get out of the constness because cub doesn't like it ...
   // (sigh)
@@ -395,7 +456,7 @@
       using U = cu::ReduceResult<OP, T>::type;
 
       // Calculate the grid and block dims
-      cu::RowReduceArgs args(in, plan, axes);
+      args.convert_shapes_to_contiguous(x, axes);
       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
       size_t reductions = args.row_size / N_READS;
       int threads = std::min(1024UL, reductions);
@@ -419,6 +480,66 @@
   });
 }
 
+void row_reduce_atomics(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::RowReduceArgs args) {
+  constexpr int N_READS = 8;
+
+  // Allocate data for the output using in's layout to access it as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  // Initialize
+  reduce_initialize(encoder, out, reduce_type);
+
+  // Launch the reduction
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+
+        args.convert_shapes_to_contiguous(x, axes);
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        if (grid.x * args.non_row_reductions < INT_MAX) {
+          grid.x *= args.non_row_reductions;
+        } else if (grid.y * args.non_row_reductions < 65536) {
+          grid.y *= args.non_row_reductions;
+        } else {
+          throw std::runtime_error(
+              "[row_reduce_atomics] Non-row reductions need to be factorized which is NYI");
+        }
+        size_t reductions = args.row_size / N_READS;
+        int threads = std::min(1024UL, reductions);
+        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+        dim3 block(threads, 1, 1);
+
+        // Pick the kernel
+        auto kernel = cu::row_reduce_atomics<T, U, OP, 32, N_READS>;
+        MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
+          kernel = cu::row_reduce_atomics<T, U, OP, THREADS, N_READS>;
+          block.x = THREADS;
+        });
+
+        // Launch
+        kernel<<<grid, block, 0, stream>>>(
+            x.data<T>(), out.data<U>(), out.size(), args);
+      });
+    });
+  });
+}
+
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -430,10 +551,20 @@
   // it has stride 1.
   if (plan.shape.size() == 1) {
     row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
+    return;
+  }
+
+  // Make the args struct to help route to the best kernel
+  cu::RowReduceArgs args(in, plan, axes);
+
+  // Let's use atomics to increase parallelism
+  if (false && args.row_size < 512) {
+    row_reduce_atomics(
+        encoder, in, out, reduce_type, axes, plan, std::move(args));
   }
 
   // Fallback row reduce
-  row_reduce_looped(encoder, in, out, reduce_type, axes, plan);
+  row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
 
   // encoder.launch_kernel([&](cudaStream_t stream) {
   //   MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {