From 772f471ff265ad21996565161fa48811b9ed6b91 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Jun 2025 12:59:20 -0700 Subject: [PATCH] [CUDA] Fix reductions (#2314) --- benchmarks/python/comparative/bench_torch.py | 22 +- mlx/backend/common/reduce.cpp | 15 +- mlx/backend/common/reduce.h | 4 + mlx/backend/cuda/CMakeLists.txt | 3 +- mlx/backend/cuda/binary_two.cu | 2 +- mlx/backend/cuda/reduce.cu | 38 +- mlx/backend/cuda/reduce/all_reduce.cu | 150 ++++++++ mlx/backend/cuda/reduce/col_reduce.cu | 296 +++++++------- mlx/backend/cuda/reduce/init_reduce.cu | 50 +++ mlx/backend/cuda/reduce/reduce.cuh | 12 +- mlx/backend/cuda/reduce/reduce_ops.cuh | 55 ++- mlx/backend/cuda/reduce/reduce_utils.cuh | 158 ++++++++ mlx/backend/cuda/reduce/row_reduce.cu | 383 ++++++++++++------- mlx/backend/cuda/reduce/segmented_reduce.cu | 84 ---- mlx/backend/cuda/softmax.cu | 4 +- python/tests/cuda_skip.py | 5 - 16 files changed, 862 insertions(+), 419 deletions(-) create mode 100644 mlx/backend/cuda/reduce/all_reduce.cu create mode 100644 mlx/backend/cuda/reduce/init_reduce.cu create mode 100644 mlx/backend/cuda/reduce/reduce_utils.cuh delete mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a2157707b..dd3436d9a 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -5,6 +5,7 @@ import os import time import torch +import torch.cuda import torch.mps @@ -44,8 +45,10 @@ def bench(f, *args): def sync_if_needed(x): - if x.device != torch.device("cpu"): + if x.device == torch.device("mps"): torch.mps.synchronize() + elif x.device == torch.device("cuda"): + torch.cuda.synchronize() @torch.no_grad() @@ -99,6 +102,14 @@ def reduction(op, axis, x): sync_if_needed(x) +@torch.no_grad() +def sum_and_add(axis, x, y): + z = x.sum(axis=axis, keepdims=True) + for i in range(50): + z = (z + y).sum(axis=axis, keepdims=True) + sync_if_needed(x) + + @torch.no_grad() def softmax(axis, x): ys = [] @@ -340,7 +351,11 @@ if __name__ == "__main__": args.axis.pop(0) torch.set_num_threads(1) - device = "cpu" if args.cpu else "mps" + device = "mps" + if torch.cuda.is_available(): + device = "cuda" + if args.cpu: + device = "cpu" types = args.dtype if not types: @@ -460,5 +475,8 @@ if __name__ == "__main__": elif args.benchmark == "selu": print(bench(selu, x)) + elif args.benchmark == "sum_and_add": + print(bench(sum_and_add, axis, *xs)) + else: raise ValueError(f"Unknown benchmark `{args.benchmark}`.") diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 5c7f63b75..ceef46400 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -5,11 +5,9 @@ namespace mlx::core { std::pair shapes_without_reduction_axes( - const array& x, + Shape shape, + Strides strides, const std::vector& axes) { - auto shape = x.shape(); - auto strides = x.strides(); - for (int i = axes.size() - 1; i >= 0; i--) { int a = axes[i]; shape.erase(shape.begin() + a); @@ -19,6 +17,15 @@ std::pair shapes_without_reduction_axes( return std::make_pair(shape, strides); } +std::pair shapes_without_reduction_axes( + const array& x, + const std::vector& axes) { + auto shape = x.shape(); + auto strides = x.strides(); + return shapes_without_reduction_axes( + std::move(shape), std::move(strides), axes); +} + ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // The data is all there and we are reducing over everything if 
(x.size() == x.data_size() && axes.size() == x.ndim() && diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index ddb5c3492..8b24f4f53 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes); +std::pair shapes_without_reduction_axes( + Shape shape, + Strides strides, + const std::vector& axes); } // namespace mlx::core diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index ad979a13f..8130d396f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -29,9 +29,10 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu - ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 3047e39f0..074c947da 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -157,7 +157,7 @@ void binary_op_gpu_inplace( if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { auto kernel = - &cu::binary_g_nd; + cu::binary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large); kernel<<>>( diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu index a740113db..8350eebb7 100644 --- a/mlx/backend/cuda/reduce.cu +++ b/mlx/backend/cuda/reduce.cu @@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(!axes_.empty()); assert(out.size() != in.size()); - out.set_data(allocator::malloc(out.nbytes())); - auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - encoder.set_input_array(in); - encoder.set_output_array(out); - // Fill out with init value. if (in.size() == 0) { - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, { - using InType = cuda_type_t; - using OutType = cu::ReduceResult::type; - thrust::fill_n( - cu::thrust_policy(stream), - thrust::device_pointer_cast(out.data()), - out.data_size(), - cu::ReduceInit::value()); - }); - }); - }); + init_reduce(encoder, in, out, reduce_type_); return; } @@ -51,7 +34,19 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // If it is a general reduce then copy the input to a contiguous array and // recompute the plan. - if (plan.type == GeneralReduce) { + // + // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting + // like we do in Metal. When it comes to broadcasted reduction axes + // some can be ignored eg for min/max. 
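  // Illustrative aside (editor's sketch, not code added by this patch): a
  // non-reduction axis with stride 0 comes from broadcasting. For example, an
  // input broadcast from shape {1, 8} to {16, 8} has strides {0, 1}; reducing
  // over axis 1 leaves axis 0 with stride 0, so the loop below flags it and we
  // fall back to a contiguous copy. Relatedly, the refactored helper
  // shapes_without_reduction_axes applied to shape {4, 5, 6}, strides
  // {30, 6, 1} and axes {1} yields shape {4, 6} with strides {30, 1}.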
+ bool broadcasted = false; + for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + broadcasted = in.strides(i) == 0; + } + } + if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { array in_copy(in.shape(), in.dtype(), nullptr, {}); copy_gpu(in, in_copy, CopyType::General, s); encoder.add_temporary(in_copy); @@ -59,9 +54,8 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { plan = get_reduction_plan(in, axes_); } - if ((plan.type == ContiguousAllReduce) || - (plan.type == ContiguousReduce && plan.shape.size() == 1)) { - segmented_reduce(encoder, in, out, reduce_type_, axes_, plan); + if (plan.type == ContiguousAllReduce) { + all_reduce(encoder, in, out, reduce_type_); return; } diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu new file mode 100644 index 000000000..5a7c28041 --- /dev/null +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -0,0 +1,150 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { + // TODO: Process multiple "rows" in each thread + constexpr int M = 1; + + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = cu::ReduceInit::value(); + ReduceOp op; + + T vals[N]; + U accs[M]; + accs[0] = init; + + size_t start = grid.block_rank() * block_step; + size_t end = start + block_step; + size_t check = min(end, size); + + size_t i = start; + for (; i + block.size() * N <= check; i += block.size() * N) { + cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], __cast(vals[j])); + } + } + + if (i < check) { + cub::LoadDirectBlocked( + block.thread_rank(), in + i, vals, check - i, __cast(init)); + for (int i = 0; i < N; i++) { + accs[0] = op(accs[0], __cast(vals[i])); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[grid.block_rank()] = accs[0]; + } +} + +} // namespace cu + +void all_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + constexpr int N_READS = 8; + + out.set_data(allocator::malloc(out.nbytes())); + + auto get_args = [](size_t size, int N) { + int threads = std::min(512UL, (size + N - 1) / N); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = + (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + int blocks, threads; + size_t block_step; + size_t insize = in.size(); + Dtype dt = in.dtype(); + + // Cub doesn't like const pointers for load (sigh). 
+ void* indata = const_cast(in.data()); + + // Large array so allocate an intermediate and accumulate there + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(in); + if (blocks > 1) { + array intermediate({blocks}, out.dtype(), nullptr, {}); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); + encoder.add_temporary(intermediate); + encoder.set_output_array(intermediate); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(dt, CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + auto kernel = cu::all_reduce; + kernel<<>>( + static_cast(indata), + intermediate.data(), + block_step, + insize); + }); + }); + }); + + // Set the input for the next step and recalculate the blocks + indata = intermediate.data(); + dt = intermediate.dtype(); + insize = intermediate.size(); + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(intermediate); + } + + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(dt, CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + auto kernel = cu::all_reduce; + kernel<<>>( + static_cast(indata), out.data(), block_step, insize); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 9911a6fe0..192a9b3e8 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +#include + #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" @@ -36,19 +38,36 @@ struct ColReduceArgs { const array& in, const ReductionPlan& plan, const std::vector& axes) { + using ShapeVector = decltype(plan.shape); + using StridesVector = decltype(plan.strides); + + ShapeVector shape_vec; + StridesVector strides_vec; + assert(!plan.shape.empty()); reduction_size = plan.shape.back(); reduction_stride = plan.strides.back(); int64_t stride_back = 1; - auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); while (!shape_vec.empty() && stride_back < reduction_stride) { stride_back *= shape_vec.back(); shape_vec.pop_back(); strides_vec.pop_back(); } + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + ShapeVector sorted_shape; + StridesVector sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(shape_vec, strides_vec); + collapse_contiguous_dims(sorted_shape, sorted_strides); shape = const_param(shape_vec); strides = const_param(strides_vec); ndim = shape_vec.size(); @@ -64,86 +83,6 @@ struct ColReduceArgs { } }; -template -__global__ void col_reduce_small( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - - int column = - grid.block_index().x * block.dim_threads().x + block.thread_index().x; - if (column * N_READS >= args.reduction_stride) { - return; - } - - int out_idx = 
grid.block_rank() / grid.dim_blocks().x; - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - - Op op; - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = ReduceInit::value(); - } - - // Read input to local. - LoopedElemToLoc 2)> loop(args.reduce_ndim); - loop.next( - block.thread_index().y, - args.reduce_shape.data(), - args.reduce_strides.data()); - for (size_t r = block.thread_index().y; - r < args.non_col_reductions * args.reduction_size; - r += block.dim_threads().y) { - U vals[N_READS]; - cub::LoadDirectBlocked( - column, - make_cast_iterator(in + loop.location()), - vals, - args.reduction_stride, - ReduceInit::value()); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(vals[i], totals[i]); - } - loop.next( - block.dim_threads().y, - args.reduce_shape.data(), - args.reduce_strides.data()); - } - - // Do block reduce when each column has more than 1 element to reduce. - if (block.dim_threads().y > 1) { - __shared__ U shared_vals[32 * 8 * N_READS]; - size_t col = - block.thread_index().y * block.dim_threads().x + block.thread_index().x; - for (int i = 0; i < N_READS; i++) { - shared_vals[col * N_READS + i] = totals[i]; - } - block.sync(); - if (block.thread_index().y == 0) { - for (int i = 0; i < N_READS; i++) { - totals[i] = shared_vals[block.thread_index().x * N_READS + i]; - } - for (int j = 1; j < block.dim_threads().y; j++) { - col = j * block.dim_threads().x + block.thread_index().x; - for (int i = 0; i < N_READS; i++) { - totals[i] = op(shared_vals[col * N_READS + i], totals[i]); - } - } - } - } - - // Write result. - if (block.thread_index().y == 0) { - cub::StoreDirectBlocked( - column, - out + out_idx * args.reduction_stride, - totals, - args.reduction_stride); - } -} - template < typename T, typename U, @@ -152,67 +91,94 @@ template < int BM, int BN, int N_READS = 4> -__global__ void col_reduce_looped( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args) { +__global__ void +col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - constexpr int n_warps = BN / N_READS; + constexpr int threads_per_row = BN / N_READS; - int out_idx = grid.block_rank() / grid.dim_blocks().x; - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + // Compute the indices for the tile + size_t tile_idx = grid.block_rank(); + size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); + size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); + // Compute the indices for the thread within the tile + short thread_x = block.thread_rank() % threads_per_row; + short thread_y = block.thread_rank() / threads_per_row; + + // Move the input pointer + in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + + tile_x * BN; + + // Initialize the running totals Op op; U totals[N_READS]; for (int i = 0; i < N_READS; i++) { totals[i] = ReduceInit::value(); } - // Read input to local. 
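  // Tiling note for the new col_reduce_looped above (editor's illustration):
  // each threadblock owns a BM x BN tile of the "matrix" whose columns are the
  // reduction outputs. With BN = 32 and reduction_stride = 100 there are
  // ceil(100 / 32) = 4 tiles per output row; block rank 9 gets tile_x = 1,
  // tile_y = 2, i.e. columns 32..63 of the third row of outputs. Only the last
  // tile in a row (4 columns here) needs the bounds-checked load path below.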
- int r = block.thread_rank() / n_warps; - int column = block.thread_rank() % n_warps; - int in_offset = grid.block_index().x * BN; LoopedElemToLoc 2)> loop(args.reduce_ndim); - loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); - for (; r < args.non_col_reductions * args.reduction_size; r += BM) { - U vals[N_READS]; - cub::LoadDirectBlocked( - column, - make_cast_iterator(in + loop.location() + in_offset), - vals, - args.reduction_stride - in_offset, - ReduceInit::value()); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(vals[i], totals[i]); + loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data()); + size_t total = args.non_col_reductions * args.reduction_size; + if (tile_x * BN + BN <= args.reduction_stride) { + if (args.reduction_stride % N_READS == 0) { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], __cast(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } else { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], __cast(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + } else { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked( + thread_x, + in + loop.location(), + vals, + args.reduction_stride - tile_x * BN, + __cast(ReduceInit::value())); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], __cast(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } // Do warp reduce for each output. - constexpr int n_outputs = BN / n_warps; + constexpr int n_outputs = BN / threads_per_row; static_assert(BM == 32 && n_outputs == N_READS); __shared__ U shared_vals[BM * BN]; - size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; + short s_idx = thread_y * BN + thread_x * N_READS; for (int i = 0; i < N_READS; i++) { - shared_vals[col + i] = totals[i]; + shared_vals[s_idx + i] = totals[i]; } block.sync(); - col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; for (int i = 0; i < n_outputs; i++) { - totals[i] = cg::reduce(warp, shared_vals[col + i], op); + totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); } // Write result. 
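  // Layout note (editor's illustration): with BM = BN = 32 and N_READS = 4 the
  // block has 256 threads in 8 warps. shared_vals is a BM x BN tile of partial
  // sums; thread (thread_y, thread_x) stores its 4 totals in row thread_y at
  // columns 4 * thread_x .. 4 * thread_x + 3. After the sync, lane t of warp w
  // re-reads row t at columns 4 * w .. 4 * w + 3, so cg::reduce collapses the
  // 32 rows and lane 0 of each warp ends up holding 4 finished outputs, which
  // it writes just below.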
if (warp.thread_rank() == 0) { - size_t out_offset = grid.block_index().x * BN; cub::StoreDirectBlocked( warp.meta_group_rank(), - out + out_idx * args.reduction_stride + out_offset, + out + tile_y * args.reduction_stride + tile_x * BN, totals, - args.reduction_stride - out_offset); + args.reduction_stride - tile_x * BN); } } @@ -220,14 +186,55 @@ __global__ void col_reduce_looped( inline auto output_grid_for_col_reduce( const array& out, - const cu::ColReduceArgs& args) { - auto out_shape = out.shape(); - auto out_strides = out.strides(); - while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { - out_shape.pop_back(); - out_strides.pop_back(); + const cu::ColReduceArgs& args, + int bn) { + int gx, gy = 1; + size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); + size_t n_outer_blocks = out.size() / args.reduction_stride; + size_t n_blocks = n_outer_blocks * n_inner_blocks; + while (n_blocks / gy > INT32_MAX) { + gy *= 2; } - return get_2d_grid_dims(out_shape, out_strides); + gx = cuda::ceil_div(n_blocks, gy); + + return dim3(gx, gy, 1); +} + +void col_reduce_looped( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + cu::ColReduceArgs args) { + // Allocate data for the output using in's layout to access them as + // contiguously as possible. + allocate_same_layout(out, in, axes); + + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + auto kernel = cu::col_reduce_looped; + kernel<<>>(indata, out.data(), args); + }); + }); + }); + }); } void col_reduce( @@ -237,42 +244,23 @@ void col_reduce( Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan) { + // Current col reduce options + // + // - col_reduce_looped + // + // It is a general strided reduce. Each threadblock computes the output for + // a subrow of the fast moving axis. For instance 32 elements. + // + // Notes: As in row reduce we opt to read as much in order as possible and + // leave transpositions as they are (contrary to our Metal backend). 
+ // + // Moreover we need different kernels for short rows and tuning + + // Make the args struct to help route to the best kernel cu::ColReduceArgs args(in, plan, axes); - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - using InType = cuda_type_t; - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using OutType = cu::ReduceResult::type; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - constexpr int N_READS = 4; - dim3 block_dims; - dim3 num_blocks = output_grid_for_col_reduce(out, args); - num_blocks.z = num_blocks.y; - num_blocks.y = num_blocks.x; - auto kernel = - cu::col_reduce_small; - size_t total = args.non_col_reductions * args.reduction_size; - if (total < 32) { - size_t stride_blocks = - cuda::ceil_div(args.reduction_stride, N_READS); - block_dims.x = std::min(stride_blocks, 32ul); - block_dims.y = std::min(total, 8ul); - num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x); - } else { - constexpr int BM = 32; - constexpr int BN = 32; - block_dims.x = BM * BN / N_READS; - num_blocks.x = cuda::ceil_div(args.reduction_stride, BN); - kernel = cu:: - col_reduce_looped; - } - kernel<<>>( - in.data(), out.data(), args); - }); - }); - }); - }); + // Fallback col reduce + col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu new file mode 100644 index 000000000..50fe109c4 --- /dev/null +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -0,0 +1,50 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void init_reduce(U* out, size_t size) { + auto index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = ReduceInit::value(); + } +} + +} // namespace cu + +void init_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + // Allocate if needed + if (out.data_shared_ptr() == nullptr) { + out.set_data(allocator::malloc(out.nbytes())); + } + + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + auto kernel = cu::init_reduce; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block(grid.x < 1024 ? 
grid.x : 1024, 1, 1); + grid.x = (grid.x + 1023) / 1024; + kernel<<>>(out.data(), out.size()); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index a673e052e..a7262bcc2 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -47,13 +47,11 @@ namespace mlx::core { throw std::invalid_argument("Unknown reduce type."); \ } -void segmented_reduce( +void all_reduce( cu::CommandEncoder& encoder, const array& in, array& out, - Reduce::ReduceType reduce_type, - const std::vector& axes, - const ReductionPlan& plan); + Reduce::ReduceType reduce_type); void row_reduce( cu::CommandEncoder& encoder, @@ -71,4 +69,10 @@ void col_reduce( const std::vector& axes, const ReductionPlan& plan); +void init_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 832787222..b40d2bd4e 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -3,48 +3,89 @@ #pragma once #include "mlx/backend/cuda/device/utils.cuh" +#include "mlx/backend/cuda/reduce/reduce_utils.cuh" namespace mlx::core::cu { // Reduce ops. struct And { - __device__ bool operator()(bool a, bool b) { + __device__ __forceinline__ bool operator()(bool a, bool b) { return a && b; } + + __device__ void atomic_update(bool* x, bool y) { + atomic_reduce(x, y); + } }; struct Or { - __device__ bool operator()(bool a, bool b) { + __device__ __forceinline__ bool operator()(bool a, bool b) { return a || b; } + + __device__ void atomic_update(bool* x, bool y) { + atomic_reduce(x, y); + } }; struct Sum { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a + b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } + + __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) { + atomicAdd(x, y); + } + + __device__ void atomic_update(int* x, int y) { + atomicAdd(x, y); + } + + __device__ void atomic_update(float* x, float y) { + atomicAdd(x, y); + } }; struct Prod { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a * b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } }; struct Min { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a < b ? a : b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } }; struct Max { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a > b ? a : b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } }; // Traits to get the result type of reduce op. @@ -120,7 +161,7 @@ template struct ReduceInit { static constexpr __host__ __device__ auto value() { if constexpr (cuda::std::is_same_v) { - return T{1, 1}; + return T{1, 0}; } else { return typename ReduceResult::type{1}; } diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh new file mode 100644 index 000000000..d4670503a --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -0,0 +1,158 @@ +// Copyright © 2025 Apple Inc. 
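// Editor's note on the ReduceInit change above (illustrative): the identity of
// a complex product is 1 + 0i, so initializing the accumulator with T{1, 1}
// made every complex prod() start from 1 + 1i, e.g. prod({2 + 0i}) evaluated
// to (1 + 1i) * 2 = 2 + 2i instead of 2. T{1, 0} restores the correct
// identity.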
+ +#pragma once + +#include + +#include "mlx/backend/cuda/device/utils.cuh" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct uint_by_size; +template <> +struct uint_by_size<2> { + using type = uint16_t; +}; +template <> +struct uint_by_size<4> { + using type = uint32_t; +}; +template <> +struct uint_by_size<8> { + using type = unsigned long long int; +}; + +template +__device__ void atomic_reduce(T* x, T y) { + if constexpr (sizeof(T) == 1) { + using U = uint16_t; + U* x_int = (U*)((char*)x - ((size_t)x % 2)); + int shift = ((char*)x - (char*)x_int) * 8; + int mask = 0xff << shift; + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(static_cast((old_val >> shift) & 0xff), y); + new_val = (old_val & ~mask) | (result << shift); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } else { + using U = typename uint_by_size::type; + U* x_int = (U*)(x); + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(*((T*)&old_val), y); + new_val = *((U*)&result); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } +} + +// TODO: Should make a custom complex type +template +inline __device__ U __cast(T x) { + return static_cast(x); +} + +template <> +inline __device__ bool __cast(cuComplex x) { + return x.x != 0 && x.y != 0; +} + +template <> +inline __device__ cuComplex __cast(bool x) { + return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0); +} + +template +inline __device__ void +block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { + // First reduce in the current warp + for (int i = 0; i < N; i++) { + vals[i] = cg::reduce(warp, vals[i], op); + } + + // Reduce across warps + if (warp.meta_group_size() > 1) { + if (warp.thread_rank() == 0) { + for (int i = 0; i < N; i++) { + smem[warp.meta_group_rank() * N + i] = vals[i]; + } + } + block.sync(); + if (warp.thread_rank() < warp.meta_group_size()) { + for (int i = 0; i < N; i++) { + vals[i] = smem[warp.thread_rank() * N + i]; + } + } else { + for (int i = 0; i < N; i++) { + vals[i] = init; + } + } + for (int i = 0; i < N; i++) { + vals[i] = cg::reduce(warp, vals[i], op); + } + } +} + +} // namespace cu + +inline void allocate_same_layout( + array& out, + const array& in, + const std::vector& axes) { + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + + if (out.ndim() < in.ndim()) { + throw std::runtime_error( + "Reduction without keepdims only supported for row-contiguous inputs"); + } + + // Calculate the transpositions applied to in in order to apply them to out. 
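  // Concrete example (editor's sketch): for a column-major in of shape {4, 5}
  // with strides {1, 4}, reducing axis 1 with keepdims gives out.shape =
  // {4, 1}. Sorting axes by descending input stride gives axis_order = {1, 0},
  // the permuted output shape {1, 4} gets row-contiguous strides {4, 1}, and
  // undoing the permutation allocates out with strides {1, 4}, i.e. the same
  // column-major layout as in. The axis_order computation below implements
  // this.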
+ std::vector axis_order(in.ndim()); + std::iota(axis_order.begin(), axis_order.end(), 0); + std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) { + return in.strides(left) > in.strides(right); + }); + + // Transpose the shape and calculate the strides + Shape out_shape(in.ndim()); + Strides out_strides(in.ndim(), 1); + for (int i = 0; i < in.ndim(); i++) { + out_shape[i] = out.shape(axis_order[i]); + } + for (int i = in.ndim() - 2; i >= 0; i--) { + out_strides[i] = out_shape[i + 1] * out_strides[i + 1]; + } + + // Reverse the axis order to get the final strides + Strides final_strides(in.ndim()); + for (int i = 0; i < in.ndim(); i++) { + final_strides[axis_order[i]] = out_strides[i]; + } + + // Calculate the resulting contiguity and do the memory allocation + auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides); + auto fl = in.flags(); + fl.row_contiguous = rc; + fl.col_contiguous = cc; + fl.contiguous = true; + out.set_data( + allocator::malloc(out.nbytes()), + data_size, + final_strides, + fl, + allocator::free); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index ae54a27d6..6a8a35311 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +#include + #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" @@ -55,84 +57,108 @@ struct RowReduceArgs { non_row_reductions *= reduce_shape[i]; } } + + // Convert shape and strides as if in was contiguous + void sort_access_pattern(const array& in, const std::vector& axes) { + auto shape_vec = in.shape(); + auto strides_vec = in.strides(); + std::tie(shape_vec, strides_vec) = + shapes_without_reduction_axes(shape_vec, strides_vec, axes); + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + decltype(shape_vec) sorted_shape; + decltype(strides_vec) sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + } }; -template -__global__ void row_reduce_small( - const T* in, - U* out, - size_t out_size, - const __grid_constant__ RowReduceArgs args) { - size_t out_idx = cg::this_grid().thread_rank(); - if (out_idx >= out_size) { - return; - } - - Op op; - - U total_val = ReduceInit::value(); - LoopedElemToLoc 2)> loop(args.reduce_ndim); - - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - - for (size_t n = 0; n < args.non_row_reductions; n++) { - for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { - U vals[N_READS]; - cub::LoadDirectBlocked( - r, - make_cast_iterator(in + loop.location()), - vals, - args.row_size, - ReduceInit::value()); - total_val = op(total_val, cub::ThreadReduce(vals, op)); - } - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); - } - - out[out_idx] = total_val; -} - -template -__global__ void row_reduce_small_warp( - const T* in, - U* out, - size_t out_size, - const __grid_constant__ RowReduceArgs args) { +template +__global__ void row_reduce_simple(T* in, U* out, 
size_t n_rows, int size) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - size_t out_idx = grid.thread_rank() / WARP_SIZE; - if (out_idx >= out_size) { - return; + const U init = cu::ReduceInit::value(); + ReduceOp op; + + T vals[M][N]; + U accs[M]; + for (int i = 0; i < M; i++) { + accs[i] = init; } - Op op; + const size_t start_row = + min(n_rows - M, static_cast(grid.block_rank() * M)); + const size_t full_blocks = size / (block.size() * N); + const size_t final_offset = full_blocks * (block.size() * N); + in += start_row * size; + out += start_row; - U total_val = ReduceInit::value(); - LoopedElemToLoc 2)> loop(args.reduce_ndim); - - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - - for (size_t n = warp.thread_rank(); n < args.non_row_reductions; - n += WARP_SIZE) { - for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { - U vals[N_READS]; - cub::LoadDirectBlocked( - r, - make_cast_iterator(in + loop.location()), - vals, - args.row_size, - ReduceInit::value()); - total_val = op(total_val, cub::ThreadReduce(vals, op)); + if (size % N == 0) { + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlockedVectorized( + block.thread_rank(), + in + k * size + r * (block.size() * N), + vals[k]); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], __cast(vals[k][j])); + } + } + } + } else { + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlocked( + block.thread_rank(), + in + k * size + r * (block.size() * N), + vals[k]); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], __cast(vals[k][j])); + } + } } - loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data()); } - total_val = cg::reduce(warp, total_val, op); + if (final_offset < size) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlocked( + block.thread_rank(), + in + k * size + final_offset, + vals[k], + size, + __cast(init)); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], __cast(vals[k][j])); + } + } + } - if (warp.thread_rank() == 0) { - out[out_idx] = total_val; + __shared__ U shared_accumulators[32 * M]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + if (grid.block_rank() * M + M <= n_rows) { + for (int i = 0; i < M; i++) { + out[i] = accs[i]; + } + } else { + short offset = grid.block_rank() * M + M - n_rows; + for (int i = offset; i < M; i++) { + out[i] = accs[i]; + } + } } } @@ -141,55 +167,165 @@ template < typename U, typename Op, int NDIM, - int BLOCK_DIM_X, + int BLOCK_DIM, int N_READS = 4> __global__ void row_reduce_looped( - const T* in, + T* in, U* out, size_t out_size, const __grid_constant__ RowReduceArgs args) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); - size_t out_idx = grid.thread_rank() / BLOCK_DIM_X; - if (out_idx >= out_size) { - return; - } + size_t out_idx = grid.block_rank(); Op op; - U total_val = ReduceInit::value(); + U total[1]; + U init = ReduceInit::value(); + total[0] = init; LoopedElemToLoc 2)> loop(args.reduce_ndim); + size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS); + size_t final_offset = full_blocks * BLOCK_DIM * N_READS; in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); for (size_t n = 0; n < args.non_row_reductions; n++) { - for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS); - r++) { - 
U vals[N_READS]; - cub::LoadDirectBlocked( - r * BLOCK_DIM_X + block.thread_index().x, - make_cast_iterator(in + loop.location()), - vals, - args.row_size, - ReduceInit::value()); - total_val = op(total_val, cub::ThreadReduce(vals, op)); + for (size_t r = 0; r < full_blocks; r++) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized( + block.thread_rank(), + in + loop.location() + r * BLOCK_DIM * N_READS, + vals); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], __cast(vals[i])); + } } + if (final_offset < args.row_size) { + T vals[N_READS]; + cub::LoadDirectBlocked( + block.thread_rank(), + in + loop.location() + final_offset, + vals, + args.row_size - final_offset, + __cast(init)); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], __cast(vals[i])); + } + } + // TODO: Maybe block.sync() here? loop.next(args.reduce_shape.data(), args.reduce_strides.data()); } - typedef cub::BlockReduce BlockReduceT; - __shared__ typename BlockReduceT::TempStorage temp; - - total_val = BlockReduceT(temp).Reduce(total_val, op); + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, total, shared_accumulators, op, init); if (block.thread_rank() == 0) { - out[out_idx] = total_val; + out[out_idx] = total[0]; } } } // namespace cu +void row_reduce_simple( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + constexpr int N_READS = 8; + + // Allocate data for the output using in's layout to avoid elem_to_loc in the + // kernel. + allocate_same_layout(out, in, axes); + + // TODO: If out.size() < 1024 which will be a common case then write this in + // 2 passes. Something like 32 * out.size() and then do a warp reduce. + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + // Calculate the grid and block dims + size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); + + // Pick the kernel + auto kernel = cu::row_reduce_simple; + if (grid.x >= 1024) { + grid.x = (grid.x + 1) / 2; + kernel = cu::row_reduce_simple; + } + + // Launch + kernel<<>>( + indata, out.data(), out.size(), plan.shape.back()); + }); + }); + }); +} + +void row_reduce_looped( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + cu::RowReduceArgs args) { + constexpr int N_READS = 8; + + // Allocate data for the output using in's layout to access them as + // contiguously as possible. + allocate_same_layout(out, in, axes); + + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. 
(sigh) + T* indata = const_cast(in.data()); + + // Calculate the grid and block dims + args.sort_access_pattern(in, axes); + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + size_t reductions = (args.row_size + N_READS - 1) / N_READS; + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); + + // Pick the kernel + auto kernel = cu::row_reduce_looped; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + MLX_SWITCH_BLOCK_DIM(threads, THREADS, { + kernel = cu::row_reduce_looped; + block.x = THREADS; + }); + }); + + // Launch + kernel<<>>( + indata, out.data(), out.size(), args); + }); + }); + }); +} + void row_reduce( cu::CommandEncoder& encoder, const array& in, @@ -197,54 +333,35 @@ void row_reduce( Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan) { + // Current row reduction options + // + // - row_reduce_simple + // + // That means that we are simply reducing across the fastest moving axis. + // We are reducing 1 or 2 rows per threadblock depending on the size of + // output. + // + // - row_reduce_looped + // + // It is a general row reduction. We are computing 1 output per + // threadblock. We read the fastest moving axis vectorized and loop over + // the rest of the axes. + // + // Notes: We opt to read as much in order as possible and leave + // transpositions as they are (contrary to our Metal backend). + + // Simple row reduce means that we have 1 axis that we are reducing over and + // it has stride 1. + if (plan.shape.size() == 1) { + row_reduce_simple(encoder, in, out, reduce_type, axes, plan); + return; + } + + // Make the args struct to help route to the best kernel cu::RowReduceArgs args(in, plan, axes); - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - using InType = cuda_type_t; - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using OutType = cu::ReduceResult::type; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - constexpr size_t N_READS = 4; - dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims, num_blocks; - auto kernel = - cu::row_reduce_small; - if (args.row_size <= 64) { - if ((args.non_row_reductions < 32 && args.row_size <= 8) || - (args.non_row_reductions <= 8)) { - block_dims.x = std::min(out_dims.x, 1024u); - num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x); - num_blocks.y = out_dims.y; - } else { - block_dims.x = WARP_SIZE; - num_blocks.y = out_dims.x; - num_blocks.z = out_dims.y; - kernel = - cu::row_reduce_small_warp; - } - } else { - size_t num_threads = cuda::ceil_div(args.row_size, N_READS); - num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE; - MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, { - num_blocks.y = out_dims.x; - num_blocks.z = out_dims.y; - block_dims.x = BLOCK_DIM_X; - kernel = cu::row_reduce_looped< - InType, - OutType, - OP, - NDIM, - BLOCK_DIM_X, - N_READS>; - }); - } - kernel<<>>( - in.data(), out.data(), out.size(), args); - }); - }); - }); - }); + // Fallback row reduce + row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args)); } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu deleted file mode 100644 index 114d71809..000000000 --- a/mlx/backend/cuda/reduce/segmented_reduce.cu +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright © 2025 Apple Inc. 
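// Dispatch example for the new row reductions above (editor's illustration,
// hypothetical shapes): summing the last axis of a contiguous float32 array of
// shape {1024, 4096} yields a single stride-1 reduction axis, so
// row_reduce_simple runs with N_READS = 8 and 512 threads per block; since the
// grid would be 1024 blocks, the launcher halves it and each block reduces
// M = 2 rows. Summing axes {0, 2} of a contiguous {A, B, C} array instead has
// non_row_reductions = A with row_size = C, which takes the row_reduce_looped
// path.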
- -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/cast_op.cuh" -#include "mlx/backend/cuda/reduce/reduce.cuh" - -#include -#include -#include - -namespace mlx::core { - -template -void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); -} - -template -void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR( - cub::DeviceSegmentedReduce::Reduce(temp.data(), size, args...)); -} - -struct MultiplyOp { - int factor; - __device__ int operator()(int i) { - return i * factor; - } -}; - -void segmented_reduce( - cu::CommandEncoder& encoder, - const array& in, - array& out, - Reduce::ReduceType reduce_type, - const std::vector& axes, - const ReductionPlan& plan) { - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using InType = cuda_type_t; - using OutType = cu::ReduceResult::type; - auto in_iter = cu::make_cast_iterator( - thrust::device_pointer_cast(in.data())); - auto out_ptr = thrust::device_pointer_cast(out.data()); - auto init = cu::ReduceInit::value(); - - if (plan.type == ContiguousAllReduce) { - cub_all_reduce( - encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream); - } else if (plan.type == ContiguousReduce) { - auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()}); - cub_segmented_reduce( - encoder, - in_iter, - out_ptr, - out.size(), - offsets, - offsets + 1, - OP(), - init, - stream); - } else { - throw std::runtime_error("Unsupported plan in segmented_reduce."); - } - }); - }); - }); -} - -} // namespace mlx::core diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index fc001ae75..652e6da19 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) { make_cast_iterator(in), vals, axis_size, - Limits::finite_min()); + Limits::min()); prevmax = maxval; maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); // Online normalizer calculation for softmax: @@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) { block.sync(); maxval = warp.thread_rank() < warp.meta_group_size() ? 
local_max[warp.thread_rank()] - : Limits::finite_min(); + : Limits::min(); maxval = cg::reduce(warp, maxval, max_op); normalizer = normalizer * softmax_exp(prevmax - maxval); if (warp.thread_rank() == 0) { diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index bcb95dbb7..cba642ca1 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,7 +1,6 @@ cuda_skip = { "TestArray.test_api", "TestBF16.test_arg_reduction_ops", - "TestBF16.test_reduction_ops", "TestBlas.test_complex_gemm", "TestEinsum.test_ellipses", "TestEinsum.test_opt_einsum_test_cases", @@ -13,11 +12,7 @@ cuda_skip = { "TestLayers.test_upsample", "TestOps.test_complex_ops", "TestOps.test_dynamic_slicing", - "TestOps.test_softmax", - "TestReduce.test_axis_permutation_sums", "TestReduce.test_dtypes", - "TestReduce.test_expand_sums", - "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", # Block masked matmul NYI "TestBlas.test_block_masked_matmul",
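# Editor's illustration (standalone sketch, not part of the patch): the
# benchmark changes at the top of this diff amount to preferring CUDA over MPS
# unless --cpu is passed, and synchronizing the right backend before timing:
import torch


def pick_device(force_cpu: bool = False) -> torch.device:
    # --cpu wins, otherwise prefer CUDA and fall back to MPS.
    if force_cpu:
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("mps")


def sync_if_needed(x: torch.Tensor) -> None:
    # Asynchronous backends must be synchronized before stopping the timer.
    if x.device.type == "cuda":
        torch.cuda.synchronize()
    elif x.device.type == "mps":
        torch.mps.synchronize()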