diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 213456692..80c5cc254 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -136,61 +136,8 @@ __global__ void row_reduce_small_warp(
   }
 }
 
-template <
-    typename T,
-    typename U,
-    typename Op,
-    int NDIM,
-    int BLOCK_DIM_X,
-    int N_READS = 4>
-__global__ void row_reduce_looped(
-    const T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-
-  size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
-  if (out_idx >= out_size) {
-    return;
-  }
-
-  Op op;
-
-  U total_val = ReduceInit<Op, T>::value();
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  for (size_t n = 0; n < args.non_row_reductions; n++) {
-    for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
-         r++) {
-      U vals[N_READS];
-      cub::LoadDirectBlocked(
-          r * BLOCK_DIM_X + block.thread_index().x,
-          make_cast_iterator<U>(in + loop.location()),
-          vals,
-          args.row_size,
-          ReduceInit<Op, T>::value());
-      total_val = op(total_val, cub::ThreadReduce(vals, op));
-    }
-    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
-  }
-
-  typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
-  __shared__ typename BlockReduceT::TempStorage temp;
-
-  total_val = BlockReduceT(temp).Reduce(total_val, op);
-
-  if (block.thread_rank() == 0) {
-    out[out_idx] = total_val;
-  }
-}
-
 template <typename T, typename U, typename Op, int N_READS = 4, int M = 1>
-__global__ void
-row_reduce_per_threadblock(T* in, U* out, size_t n_rows, int size) {
+__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -274,6 +221,72 @@ row_reduce_per_threadblock(T* in, U* out, size_t n_rows, int size) {
   }
 }
 
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int NDIM,
+    int BLOCK_DIM_X,
+    int N_READS = 4>
+__global__ void row_reduce_looped(
+    T* in,
+    U* out,
+    size_t out_size,
+    const __grid_constant__ RowReduceArgs args) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
+  if (out_idx >= out_size) {
+    return;
+  }
+
+  Op op;
+
+  U total_val = ReduceInit<Op, T>::value();
+  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+  size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
+  size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
+
+  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+
+  for (size_t n = 0; n < args.non_row_reductions; n++) {
+    for (size_t r = 0; r < full_blocks; r++) {
+      T vals[N_READS];
+      cub::LoadDirectBlockedVectorized(
+          block.thread_rank(),
+          in + loop.location() + r * BLOCK_DIM_X * N_READS,
+          vals);
+      for (int i = 0; i < N_READS; i++) {
+        total_val = op(total_val, __cast<U, T>(vals[i]));
+      }
+    }
+    if (final_offset < args.row_size) {
+      T vals[N_READS];
+      cub::LoadDirectBlocked(
+          block.thread_rank(),
+          in + loop.location() + final_offset,
+          vals,
+          args.row_size - final_offset,
+          __cast<T, U>(ReduceInit<Op, T>::value()));
+      for (int i = 0; i < N_READS; i++) {
+        total_val = op(total_val, __cast<U, T>(vals[i]));
+      }
+    }
+    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
+  }
+
+  typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  total_val = BlockReduceT(temp).Reduce(total_val, op);
+
+  if (block.thread_rank() == 0) {
+    out[out_idx] = total_val;
+  }
+}
+
 } // namespace cu
 
 void row_reduce_simple(
@@ -287,7 +300,7 @@ void row_reduce_simple(
 
   // Initialize out such that its strides match in's layout (except the fastest
   // moving axis)
-  auto [_, out_strides] = shapes_without_reduction_axes(in, axes);
+  auto out_strides = in.strides();
   for (auto& s : out_strides) {
     s /= plan.shape.back();
   }
@@ -321,12 +334,17 @@ void row_reduce_simple(
         size_t reductions = plan.shape.back() / N_READS;
         dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
         int threads = std::min(1024UL, reductions);
+        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
         dim3 block(threads, 1, 1);
-        auto kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS>;
+
+        // Pick the kernel
+        auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
         if (grid.x >= 1024) {
           grid.x = (grid.x + 1) / 2;
-          kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS, 2>;
+          kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
         }
+
+        // Launch
         kernel<<<grid, block, 0, stream>>>(
             x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
       });
@@ -334,6 +352,75 @@ void row_reduce_simple(
   });
 }
 
+void row_reduce_looped(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan) {
+  constexpr int N_READS = 8;
+
+  // Initialize out such that it matches in's layout. Basically we keep any
+  // transpositions as it were and that allows us to skip finding the location
+  // of the output that matches the input.
+  auto out_strides = in.strides();
+  for (auto ax : axes) {
+    for (auto& s : out_strides) {
+      if (s > in.strides(ax)) {
+        s /= in.shape(ax);
+      }
+    }
+  }
+  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
+  auto fl = in.flags();
+  fl.row_contiguous = rc;
+  fl.col_contiguous = cc;
+  fl.contiguous = data_size == out.size();
+  out.set_data(
+      allocator::malloc(out.nbytes()),
+      data_size,
+      out_strides,
+      fl,
+      allocator::free);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+
+        // Calculate the grid and block dims
+        cu::RowReduceArgs args(in, plan, axes);
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        size_t reductions = args.row_size / N_READS;
+        int threads = std::min(1024UL, reductions);
+        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+        dim3 block(threads, 1, 1);
+
+        // Pick the kernel
+        auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
+        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+          MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
+            kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
+            block.x = THREADS;
+          });
+        });
+
+        // Launch
+        kernel<<<grid, block, 0, stream>>>(
+            x.data<T>(), out.data<U>(), out.size(), args);
+      });
+    });
+  });
+}
+
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -341,10 +428,15 @@ void row_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Simple row reduce means that we have 1 axis that we are reducing over and
+  // it has stride 1.
   if (plan.shape.size() == 1) {
     row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
+    return;
   }
 
-  // cu::RowReduceArgs args(in, plan, axes);
+
+  // Fallback row reduce
+  row_reduce_looped(encoder, in, out, reduce_type, axes, plan);
   // encoder.launch_kernel([&](cudaStream_t stream) {
   //   MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {