From bc60a31cae81637dfb914c2fa4ed818f19461fef Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 26 Jun 2025 22:54:09 -0700
Subject: [PATCH] Comments

---
 mlx/backend/cuda/reduce/all_reduce.cu    | 29 ++++++++++++++++---
 mlx/backend/cuda/reduce/col_reduce.cu    | 28 +++++++++++++----------
 mlx/backend/cuda/reduce/init_reduce.cu   |  1 -
 mlx/backend/cuda/reduce/reduce_utils.cuh |  2 +-
 mlx/backend/cuda/reduce/row_reduce.cu    | 28 +++++++++++------------
 5 files changed, 49 insertions(+), 39 deletions(-)

diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index 0467da104..5a7c28041 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -96,43 +96,52 @@ void all_reduce(
   int blocks, threads;
   size_t block_step;
-  array x = in;
+  size_t insize = in.size();
+  Dtype dt = in.dtype();
+
+  // Cub doesn't like const pointers for load (sigh).
+  void* indata = const_cast<void*>(in.data<void>());
 
   // Large array so allocate an intermediate and accumulate there
-  std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+  encoder.set_input_array(in);
   if (blocks > 1) {
     array intermediate({blocks}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc(intermediate.nbytes()));
     encoder.add_temporary(intermediate);
-    encoder.set_input_array(x);
     encoder.set_output_array(intermediate);
     encoder.launch_kernel([&](cudaStream_t stream) {
-      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
         MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
           using T = cuda_type_t<CTYPE>;
           using U = cu::ReduceResult<OP, T>::type;
           auto kernel = cu::all_reduce<T, U, OP, N_READS>;
           kernel<<<blocks, threads, 0, stream>>>(
-              x.data<T>(), intermediate.data<U>(), block_step, x.size());
+              static_cast<T*>(indata),
+              intermediate.data<U>(),
+              block_step,
+              insize);
         });
       });
     });
 
     // Set the input for the next step and recalculate the blocks
-    x = intermediate;
-    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+    indata = intermediate.data<void>();
+    dt = intermediate.dtype();
+    insize = intermediate.size();
+    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+    encoder.set_input_array(intermediate);
   }
 
-  encoder.set_input_array(x);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         using T = cuda_type_t<CTYPE>;
         using U = cu::ReduceResult<OP, T>::type;
         auto kernel = cu::all_reduce<T, U, OP, N_READS>;
         kernel<<<blocks, threads, 0, stream>>>(
-            x.data<T>(), out.data<U>(), block_step, x.size());
+            static_cast<T*>(indata), out.data<U>(), block_step, insize);
       });
     });
   });
diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index 8dbebe386..192a9b3e8 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -38,12 +38,18 @@ struct ColReduceArgs {
       const array& in,
       const ReductionPlan& plan,
       const std::vector<int>& axes) {
+    using ShapeVector = decltype(plan.shape);
+    using StridesVector = decltype(plan.strides);
+
+    ShapeVector shape_vec;
+    StridesVector strides_vec;
+
     assert(!plan.shape.empty());
     reduction_size = plan.shape.back();
     reduction_stride = plan.strides.back();
 
     int64_t stride_back = 1;
-    auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
+    std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
     while (!shape_vec.empty() && stride_back < reduction_stride) {
       stride_back *= shape_vec.back();
       shape_vec.pop_back();
@@ -54,8 +60,8 @@ struct ColReduceArgs {
     std::sort(indices.begin(), indices.end(), [&](int left, int right) {
       return strides_vec[left] > strides_vec[right];
     });
-    decltype(shape_vec) sorted_shape;
-    decltype(strides_vec) sorted_strides;
+    ShapeVector sorted_shape;
+    StridesVector sorted_strides;
     for (auto idx : indices) {
       sorted_shape.push_back(shape_vec[idx]);
       sorted_strides.push_back(strides_vec[idx]);
@@ -206,26 +212,25 @@ void col_reduce_looped(
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
 
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
-  encoder.set_input_array(x);
+  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
           using T = cuda_type_t<CTYPE>;
           using U = cu::ReduceResult<OP, T>::type;
+
+          // Cub doesn't like const pointers for vectorized loads. (sigh)
+          T* indata = const_cast<T*>(in.data<T>());
+
           constexpr int N_READS = 4;
           constexpr int BM = 32;
           constexpr int BN = 32;
           dim3 grid = output_grid_for_col_reduce(out, args, BN);
           int blocks = BM * BN / N_READS;
           auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
-          kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
+          kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
         });
       });
     });
@@ -247,8 +252,7 @@ void col_reduce(
   //    a subrow of the fast moving axis. For instance 32 elements.
   //
   // Notes: As in row reduce we opt to read as much in order as possible and
-  //        leave
-  //        transpositions as they are (contrary to our Metal backend).
+  //        leave transpositions as they are (contrary to our Metal backend).
   //
   // Moreover we need different kernels for short rows and tuning
diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu
index a500dc04e..50fe109c4 100644
--- a/mlx/backend/cuda/reduce/init_reduce.cu
+++ b/mlx/backend/cuda/reduce/init_reduce.cu
@@ -31,7 +31,6 @@ void init_reduce(
     out.set_data(allocator::malloc(out.nbytes()));
   }
 
-  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index b76411261..d4670503a 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -146,7 +146,7 @@ inline void allocate_same_layout(
   auto fl = in.flags();
   fl.row_contiguous = rc;
   fl.col_contiguous = cc;
-  fl.contiguous = data_size == out.size();
+  fl.contiguous = true;
   out.set_data(
       allocator::malloc(out.nbytes()),
       data_size,
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 7e155795a..6a8a35311 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -241,20 +241,19 @@ void row_reduce_simple(
   // kernel.
   allocate_same_layout(out, in, axes);
 
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
   // TODO: If out.size() < 1024 which will be a common case then write this in
   // 2 passes. Something like 32 * out.size() and then do a warp reduce.
-  encoder.set_input_array(x);
+  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         using T = cuda_type_t<CTYPE>;
         using U = cu::ReduceResult<OP, T>::type;
 
+        // Cub doesn't like const pointers for vectorized loads. (sigh)
+        T* indata = const_cast<T*>(in.data<T>());
+
         // Calculate the grid and block dims
         size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
         dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
@@ -271,7 +270,7 @@ void row_reduce_simple(
 
         // Launch
         kernel<<<grid, threads, 0, stream>>>(
-            x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
+            indata, out.data<U>(), out.size(), plan.shape.back());
       });
     });
   });
@@ -291,20 +290,19 @@ void row_reduce_looped(
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
 
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
-  encoder.set_input_array(x);
+  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         using T = cuda_type_t<CTYPE>;
         using U = cu::ReduceResult<OP, T>::type;
 
+        // Cub doesn't like const pointers for vectorized loads. (sigh)
+        T* indata = const_cast<T*>(in.data<T>());
+
         // Calculate the grid and block dims
-        args.sort_access_pattern(x, axes);
+        args.sort_access_pattern(in, axes);
         dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
         size_t reductions = (args.row_size + N_READS - 1) / N_READS;
         int threads = std::min(1024UL, reductions);
@@ -322,7 +320,7 @@ void row_reduce_looped(
 
         // Launch
         kernel<<<grid, threads, 0, stream>>>(
-            x.data<T>(), out.data<U>(), out.size(), args);
+            indata, out.data<U>(), out.size(), args);
       });
     });
   });
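
Note (not part of the patch): the repeated const_casts above exist because, as
the patch's comments say, cub doesn't take const pointers on its vectorized
load path, so each launch sheds the const from in.data<T>() before handing the
pointer to the kernel. Below is a self-contained toy kernel, not MLX code,
showing the same workaround; sum_rows is a hypothetical name, and it assumes
every row is exactly BLOCK_DIM * N_READS floats.

#include <cstddef>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

// One block per row: vectorized load of N_READS values per thread, then a
// block-wide sum. BLOCK_LOAD_VECTORIZE is the policy whose fast overload
// wants a mutable pointer, which is what forces the const_cast.
template <int BLOCK_DIM, int N_READS>
__global__ void sum_rows(const float* in, float* out, size_t row_size) {
  using BlockLoadT =
      cub::BlockLoad<float, BLOCK_DIM, N_READS, cub::BLOCK_LOAD_VECTORIZE>;
  using BlockReduceT = cub::BlockReduce<float, BLOCK_DIM>;
  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockReduceT::TempStorage reduce;
  } temp;

  // Same trick as the patch: T* indata = const_cast<T*>(in.data<T>());
  float* indata = const_cast<float*>(in);

  float vals[N_READS];
  BlockLoadT(temp.load).Load(indata + blockIdx.x * row_size, vals);

  // Thread-local partial sum over the N_READS loaded values.
  float total = 0.0f;
  for (int i = 0; i < N_READS; i++) {
    total += vals[i];
  }

  // Block-wide sum; the result is only valid in thread 0.
  __syncthreads();
  total = BlockReduceT(temp.reduce).Sum(total);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;
  }
}

// Example launch: 1024 rows of 256 * 4 floats each.
// sum_rows<256, 4><<<1024, 256, 0, stream>>>(in, out, 256 * 4);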
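
Note (not part of the patch): the all_reduce.cu hunk keeps the same two-pass
shape as before, but tracks the moving input as a raw pointer plus dtype and
size instead of rebinding an array x. A minimal plain-CUDA sketch of that
driver pattern follows, under stated assumptions: partial_sum and sum_all are
hypothetical stand-ins for cu::all_reduce and all_reduce(), and the kernel is
deliberately simpler than the real one.

#include <cuda_runtime.h>

// Each block accumulates a contiguous block_step-sized slice of `in` into one
// partial value: a strided loop over the slice, then a shared-memory tree
// reduction. blockDim.x must be a power of two.
__global__ void partial_sum(
    const float* in, float* out, size_t block_step, size_t size) {
  extern __shared__ float smem[];
  size_t start = blockIdx.x * block_step;
  size_t stop = (start + block_step < size) ? start + block_step : size;

  float acc = 0.0f;
  for (size_t i = start + threadIdx.x; i < stop; i += blockDim.x) {
    acc += in[i];
  }
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int offset = blockDim.x / 2; offset > 0; offset /= 2) {
    if (threadIdx.x < offset) {
      smem[threadIdx.x] += smem[threadIdx.x + offset];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    out[blockIdx.x] = smem[0];
  }
}

// Two-pass driver shaped like all_reduce() in the patch: pass one writes a
// blocks-sized intermediate, then the same kernel runs again with the
// intermediate as its input, mirroring the patch's indata/insize rebinding.
void sum_all(const float* in, float* out, size_t size, cudaStream_t stream) {
  const int threads = 256;
  size_t block_step = 1 << 20; // elements per block; a tuning knob
  int blocks = static_cast<int>((size + block_step - 1) / block_step);

  const float* indata = in;
  size_t insize = size;
  float* intermediate = nullptr;
  if (blocks > 1) {
    cudaMallocAsync(&intermediate, blocks * sizeof(float), stream);
    partial_sum<<<blocks, threads, threads * sizeof(float), stream>>>(
        indata, intermediate, block_step, insize);
    // Set the input for the next step, as the patch does.
    indata = intermediate;
    insize = blocks;
    block_step = insize;
  }
  partial_sum<<<1, threads, threads * sizeof(float), stream>>>(
      indata, out, block_step, insize);
  if (intermediate != nullptr) {
    cudaFreeAsync(intermediate, stream);
  }
}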