From 664d8e42b840d7aae7811ac496694c83398ba847 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 21 Jun 2025 12:44:26 -0700
Subject: [PATCH] Add comments and clean up

---
 mlx/backend/cuda/reduce/row_reduce.cu | 305 ++------------------------
 1 file changed, 17 insertions(+), 288 deletions(-)

diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 735de5311..12bf8897b 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -77,88 +77,6 @@ struct RowReduceArgs {
   }
 };
 
-// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-//__global__ void row_reduce_small(
-//     const T* in,
-//     U* out,
-//     size_t out_size,
-//     const __grid_constant__ RowReduceArgs args) {
-//   size_t out_idx = cg::this_grid().thread_rank();
-//   if (out_idx >= out_size) {
-//     return;
-//   }
-//
-//   Op op;
-//
-//   U total_val = ReduceInit<Op, T>::value();
-//   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-//
-//   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
-//   args.ndim);
-//
-//   for (size_t n = 0; n < args.non_row_reductions; n++) {
-//     for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-//       U vals[N_READS];
-//       cub::LoadDirectBlocked(
-//           r,
-//           make_cast_iterator<U>(in + loop.location()),
-//           vals,
-//           args.row_size,
-//           ReduceInit<Op, T>::value());
-//       total_val = op(total_val, cub::ThreadReduce(vals, op));
-//     }
-//     loop.next(args.reduce_shape.data(), args.reduce_strides.data());
-//   }
-//
-//   out[out_idx] = total_val;
-// }
-//
-// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-//__global__ void row_reduce_small_warp(
-//     const T* in,
-//     U* out,
-//     size_t out_size,
-//     const __grid_constant__ RowReduceArgs args) {
-//   auto grid = cg::this_grid();
-//   auto block = cg::this_thread_block();
-//   auto warp = cg::tiled_partition<WARP_SIZE>(block);
-//
-//   size_t out_idx = grid.thread_rank() / WARP_SIZE;
-//   if (out_idx >= out_size) {
-//     return;
-//   }
-//
-//   Op op;
-//
-//   U total_val = ReduceInit<Op, T>::value();
-//   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-//
-//   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
-//   args.ndim);
-//
-//   for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
-//        n += WARP_SIZE) {
-//     for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-//       U vals[N_READS];
-//       cub::LoadDirectBlocked(
-//           r,
-//           make_cast_iterator<U>(in + loop.location()),
-//           vals,
-//           args.row_size,
-//           ReduceInit<Op, T>::value());
-//       total_val = op(total_val, cub::ThreadReduce(vals, op));
-//     }
-//     loop.next(WARP_SIZE, args.reduce_shape.data(),
-//     args.reduce_strides.data());
-//   }
-//
-//   total_val = cg::reduce(warp, total_val, op);
-//
-//   if (warp.thread_rank() == 0) {
-//     out[out_idx] = total_val;
-//   }
-// }
-
 template <typename T, typename U, typename Op, int N_READS = 4>
 __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   auto grid = cg::this_grid();
@@ -286,100 +204,8 @@ __global__ void row_reduce_looped(
   }
 }
 
-template <typename T, typename U, typename Op, int N = 4>
-__global__ void reduce_initialize(U* out, size_t out_size) {
-  auto grid = cg::this_grid();
-  if (grid.thread_rank() * N + N <= out_size) {
-    for (int i = 0; i < N; i++) {
-      out[grid.thread_rank() * N + i] = ReduceInit<Op, T>::value();
-    }
-  } else {
-    for (int i = grid.thread_rank() * N; i < out_size; i++) {
-      out[i] = ReduceInit<Op, T>::value();
-    }
-  }
-}
-
-template <typename T, typename U, typename Op, int BLOCK_DIM, int N_READS = 4>
-__global__ void row_reduce_atomics(
-    T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-  auto warp = cg::tiled_partition<WARP_SIZE>(block);
-
-  size_t reduction_idx = grid.block_rank() / out_size;
-  size_t out_idx = grid.block_rank() % out_size;
-
-  Op op;
-
-  U total[1];
-  U init = ReduceInit<Op, T>::value();
-  total[0] = init;
-  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
-  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-  in += elem_to_loc(
-      reduction_idx,
-      args.reduce_shape.data(),
-      args.reduce_strides.data(),
-      args.reduce_ndim);
-
-  for (size_t r = 0; r < full_blocks; r++) {
-    T vals[N_READS];
-    cub::LoadDirectBlockedVectorized(
-        block.thread_rank(), in + r * BLOCK_DIM * N_READS, vals);
-    for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
-    }
-  }
-  if (final_offset < args.row_size) {
-    T vals[N_READS];
-    cub::LoadDirectBlocked(
-        block.thread_rank(),
-        in + final_offset,
-        vals,
-        args.row_size - final_offset,
-        __cast<T, U>(init));
-    for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
-    }
-  }
-
-  __shared__ U shared_accumulators[32];
-  block_reduce(block, warp, total, shared_accumulators, op, init);
-
-  if (block.thread_rank() == 0) {
-    op.atomic_update(out + out_idx, total[0]);
-  }
-}
-
 } // namespace cu
 
-void reduce_initialize(
-    cu::CommandEncoder& encoder,
-    array& out,
-    Reduce::ReduceType reduce_type) {
-  constexpr int N_WRITES = 8;
-  encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
-
-        auto kernel = cu::reduce_initialize<T, U, OP, N_WRITES>;
-        auto [grid, block] =
-            get_launch_args(kernel, out, out.size() >= 1UL << 31, N_WRITES);
-        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
-      });
-    });
-  });
-}
-
 void row_reduce_simple(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -480,66 +306,6 @@ void row_reduce_looped(
   });
 }
 
-void row_reduce_atomics(
-    cu::CommandEncoder& encoder,
-    const array& in,
-    array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan,
-    cu::RowReduceArgs args) {
-  constexpr int N_READS = 8;
-
-  // Allocate data for the output using in's layout to access them as
-  // contiguously as possible.
-  allocate_same_layout(out, in, axes);
-
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
-  // Initialize
-  reduce_initialize(encoder, out, reduce_type);
-
-  // Launch the reduction
-  encoder.set_input_array(x);
-  encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
-
-        args.convert_shapes_to_contiguous(x, axes);
-        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
-        if (grid.x * args.non_row_reductions < INT_MAX) {
-          grid.x *= args.non_row_reductions;
-        } else if (grid.y * args.non_row_reductions < 65536) {
-          grid.y *= args.non_row_reductions;
-        } else {
-          throw std::runtime_error(
-              "[row_reduce_atomics] Non-row reductions need to be factorized which is NYI");
-        }
-        size_t reductions = args.row_size / N_READS;
-        int threads = std::min(1024UL, reductions);
-        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
-        dim3 block(threads, 1, 1);
-
-        // Pick the kernel
-        auto kernel = cu::row_reduce_atomics<T, U, OP, 1024, N_READS>;
-        MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
-          kernel = cu::row_reduce_atomics<T, U, OP, THREADS, N_READS>;
-          block.x = THREADS;
-        });
-
-        // Launch
-        kernel<<<grid, block, 0, stream>>>(
-            x.data<T>(), out.data<U>(), out.size(), args);
-      });
-    });
-  });
-}
-
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -547,6 +313,23 @@ void row_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current row reduction options
+  //
+  // - row_reduce_simple
+  //
+  //   That means that we are simply reducing across the fastest moving axis.
+  //   We are reducing 1 or 2 rows per threadblock depending on the size of
+  //   output.
+  //
+  // - row_reduce_looped
+  //
+  //   It is a general row reduction. We are computing 1 output per
+  //   threadblock. We read the fastest moving axis vectorized and loop over
+  //   the rest of the axes.
+  //
+  // Notes: We opt to read as much in order as possible and leave
+  //        transpositions as they are (contrary to our Metal backend).
+
   // Simple row reduce means that we have 1 axis that we are reducing over and
   // it has stride 1.
   if (plan.shape.size() == 1) {
@@ -557,62 +340,8 @@ void row_reduce(
 
   // Make the args struct to help route to the best kernel
   cu::RowReduceArgs args(in, plan, axes);
 
-  // Let's use atomics to increase parallelism
-  if (false && args.row_size < 512) {
-    row_reduce_atomics(
-        encoder, in, out, reduce_type, axes, plan, std::move(args));
-  }
-
   // Fallback row reduce
   row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
-
-  // encoder.launch_kernel([&](cudaStream_t stream) {
-  //   MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-  //     using InType = cuda_type_t<CTYPE>;
-  //     MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-  //       using OutType = cu::ReduceResult<OP, InType>::type;
-  //       MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-  //         constexpr size_t N_READS = 4;
-  //         dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
-  //         dim3 block_dims, num_blocks;
-  //         auto kernel =
-  //             cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-  //         if (args.row_size <= 64) {
-  //           if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
-  //               (args.non_row_reductions <= 8)) {
-  //             block_dims.x = std::min(out_dims.x, 1024u);
-  //             num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
-  //             num_blocks.y = out_dims.y;
-  //           } else {
-  //             block_dims.x = WARP_SIZE;
-  //             num_blocks.y = out_dims.x;
-  //             num_blocks.z = out_dims.y;
-  //             kernel =
-  //                 cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
-  //           }
-  //         } else {
-  //           size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
-  //           num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
-  //           MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
-  //             num_blocks.y = out_dims.x;
-  //             num_blocks.z = out_dims.y;
-  //             block_dims.x = BLOCK_DIM_X;
-  //             kernel = cu::row_reduce_looped<
-  //                 InType,
-  //                 OutType,
-  //                 OP,
-  //                 NDIM,
-  //                 BLOCK_DIM_X,
-  //                 N_READS>;
-  //           });
-  //         }
-  //         kernel<<<num_blocks, block_dims, 0, stream>>>(
-  //             in.data<InType>(), out.data<OutType>(), out.size(), args);
-  //       });
-  //     });
-  //   });
-  // });
 }
 
 } // namespace mlx::core
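
For readers unfamiliar with the "one output per threadblock, read the fastest moving axis and loop" pattern that the new comment describes, the standalone CUDA sketch below illustrates the general idea. It is not MLX code: the kernel name (row_sum_looped), the fixed float/sum types, and the plain strided loads (instead of MLX's cub-based vectorized loads and cooperative-groups block reduction) are simplifying assumptions made only for illustration. Each block produces one output row, each thread accumulates a private partial over the row, and partials are combined with warp shuffles plus a small shared-memory staging step.

// row_sum_sketch.cu -- illustrative sketch only, not the MLX kernel.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// BLOCK_DIM must be a multiple of the warp size (32).
template <int BLOCK_DIM>
__global__ void row_sum_looped(const float* in, float* out, int row_size) {
  // One thread block per output row.
  const float* row = in + static_cast<size_t>(blockIdx.x) * row_size;

  // Each thread strides over the row and accumulates a private partial sum.
  float partial = 0.0f;
  for (int i = threadIdx.x; i < row_size; i += BLOCK_DIM) {
    partial += row[i];
  }

  // Reduce within each warp using shuffles.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    partial += __shfl_down_sync(0xffffffff, partial, offset);
  }

  // One value per warp goes through shared memory; warp 0 finishes the job.
  __shared__ float warp_sums[BLOCK_DIM / 32];
  const int lane = threadIdx.x % warpSize;
  const int warp = threadIdx.x / warpSize;
  if (lane == 0) {
    warp_sums[warp] = partial;
  }
  __syncthreads();

  if (warp == 0) {
    partial = (lane < BLOCK_DIM / 32) ? warp_sums[lane] : 0.0f;
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
      partial += __shfl_down_sync(0xffffffff, partial, offset);
    }
    if (lane == 0) {
      out[blockIdx.x] = partial;
    }
  }
}

int main() {
  constexpr int n_rows = 8;
  constexpr int row_size = 1000;
  constexpr int block_dim = 128;

  // Fill the input with ones so each row sums to row_size.
  std::vector<float> h_in(static_cast<size_t>(n_rows) * row_size, 1.0f);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, n_rows * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);

  // One block per row, as in the looped row reduction described above.
  row_sum_looped<block_dim><<<n_rows, block_dim>>>(d_in, d_out, row_size);

  std::vector<float> h_out(n_rows);
  cudaMemcpy(h_out.data(), d_out, n_rows * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0 sum = %f (expected %d)\n", h_out[0], row_size);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}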