From c9fa68664a36ca6d8f071fe2155fd136775bc87e Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 03:22:25 +0900 Subject: [PATCH] CUDA backend: reduce (#2269) --- mlx/backend/cuda/CMakeLists.txt | 4 + mlx/backend/cuda/kernel_utils.cuh | 25 ++ mlx/backend/cuda/kernels/utils.cuh | 198 ++++++++++++++ mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/reduce.cu | 82 ++++++ mlx/backend/cuda/reduce/col_reduce.cu | 278 ++++++++++++++++++++ mlx/backend/cuda/reduce/reduce.cuh | 74 ++++++ mlx/backend/cuda/reduce/reduce_ops.cuh | 144 ++++++++++ mlx/backend/cuda/reduce/row_reduce.cu | 250 ++++++++++++++++++ mlx/backend/cuda/reduce/segmented_reduce.cu | 84 ++++++ 10 files changed, 1139 insertions(+), 1 deletion(-) create mode 100644 mlx/backend/cuda/reduce.cu create mode 100644 mlx/backend/cuda/reduce/col_reduce.cu create mode 100644 mlx/backend/cuda/reduce/reduce.cuh create mode 100644 mlx/backend/cuda/reduce/reduce_ops.cuh create mode 100644 mlx/backend/cuda/reduce/row_reduce.cu create mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 9d9657e1f..c053b4428 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -21,6 +21,10 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/random.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index aeb065206..656ddebea 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -47,6 +47,31 @@ namespace mlx::core { __VA_ARGS__; \ } +// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2. +#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \ + { \ + uint32_t _num_threads = NUM_THREADS; \ + if (_num_threads <= WARP_SIZE) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 2) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 4) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 8) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 16) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \ + __VA_ARGS__; \ + } else { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \ + __VA_ARGS__; \ + } \ + } + // Maps CPU types to CUDA types. template struct CTypeToCudaType { diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh index 16957d132..7636710dc 100644 --- a/mlx/backend/cuda/kernels/utils.cuh +++ b/mlx/backend/cuda/kernels/utils.cuh @@ -9,6 +9,8 @@ #pragma once #include +#include +#include #include #include #include @@ -19,6 +21,10 @@ namespace mlx::core::cu { // CUDA kernel utils /////////////////////////////////////////////////////////////////////////////// +// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in +// warpSize variable exists, using it would prevent compile-time optimizations. 
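// (Illustrative aside, not part of the patch: the payoff of a compile-time
// warp size is that warp-level loops can be fully unrolled, which the
// runtime-valued built-in warpSize would not allow. A minimal sketch using
// the WARP_SIZE macro defined just below; warp_sum_sketch is a hypothetical
// helper, not something this patch adds:)
template <typename T>
__device__ T warp_sum_sketch(T val) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}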
+#define WARP_SIZE 32 + // To pass shape/strides to kernels via constant memory, their size must be // known at compile time. #define MAX_NDIM 8 @@ -26,6 +32,94 @@ namespace mlx::core::cu { using Shape = cuda::std::array; using Strides = cuda::std::array; +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T min() { + return cuda::std::numeric_limits::min(); + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { + return cuda::std::numeric_limits::min(); + } +}; + +template +struct Limits< + T, + cuda::std::enable_if_t< + cuda::std::is_same_v || cuda::std::is_same_v>> { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T min() { + return -cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { + return cuda::std::numeric_limits::lowest(); + } +}; + +// CUDA 11 does not have host side arithmatic operators for half types. +template +struct Limits< + T, + cuda::std::enable_if_t< + cuda::std::is_same_v || + cuda::std::is_same_v>> { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T min() { +#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 + return -cuda::std::numeric_limits::infinity(); +#else + return -cuda::std::numeric_limits::infinity(); +#endif + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { +#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 + return cuda::std::numeric_limits::lowest(); +#else + return cuda::std::numeric_limits::lowest(); +#endif + } +}; + +template <> +struct Limits { + static constexpr __host__ __device__ bool max() { + return true; + } + static constexpr __host__ __device__ bool min() { + return false; + } +}; + +template <> +struct Limits { + static constexpr __host__ __device__ cuComplex max() { + return {Limits::max(), Limits::max()}; + } + static constexpr __host__ __device__ cuComplex min() { + return {Limits::min(), Limits::min()}; + } +}; + /////////////////////////////////////////////////////////////////////////////// // Indexing utils /////////////////////////////////////////////////////////////////////////////// @@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( return cuda::std::make_tuple(a_loc, b_loc); } +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + __device__ void next(const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + 
inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, true, OffsetT> { + int dim; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim) {} + + __device__ void next(const int* shape, const int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, false, OffsetT> { + OffsetT offset{0}; + + __device__ LoopedElemToLoc(int) {} + + __device__ void next(const int*, const int64_t* strides) { + offset += OffsetT(strides[0]); + } + + __device__ void next(int n, const int*, const int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + __device__ OffsetT location() { + return offset; + } +}; + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index caa2c33ff..1b273e959 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -91,7 +91,6 @@ NO_GPU_MULTI(LUF) NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) -NO_GPU(Reduce) NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu new file mode 100644 index 000000000..a740113db --- /dev/null +++ b/mlx/backend/cuda/reduce.cu @@ -0,0 +1,82 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Reduce::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Reduce::eval_gpu"); + assert(inputs.size() == 1); + array in = inputs[0]; + + // Make sure no identity reductions trickle down here. + assert(!axes_.empty()); + assert(out.size() != in.size()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Fill out with init value. + if (in.size() == 0) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, { + using InType = cuda_type_t; + using OutType = cu::ReduceResult::type; + thrust::fill_n( + cu::thrust_policy(stream), + thrust::device_pointer_cast(out.data()), + out.data_size(), + cu::ReduceInit::value()); + }); + }); + }); + return; + } + + // Reduce. + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. 
+ if (plan.type == GeneralReduce) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + encoder.add_temporary(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + if ((plan.type == ContiguousAllReduce) || + (plan.type == ContiguousReduce && plan.shape.size() == 1)) { + segmented_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + col_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + throw std::runtime_error("No plan reached in reduce."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu new file mode 100644 index 000000000..1ca50d854 --- /dev/null +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -0,0 +1,278 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). + size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + int column = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + if (column * N_READS >= args.reduction_stride) { + return; + } + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + // Read input to local. 
+ LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next( + block.thread_index().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + for (size_t r = block.thread_index().y; + r < args.non_col_reductions * args.reduction_size; + r += block.dim_threads().y) { + U vals[N_READS]; + cub::LoadDirectBlocked( + column, + make_cast_iterator(in + loop.location()), + vals, + args.reduction_stride, + ReduceInit::value()); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next( + block.dim_threads().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + } + + // Do block reduce when each column has more than 1 element to reduce. + if (block.dim_threads().y > 1) { + __shared__ U shared_vals[32 * 8 * N_READS]; + size_t col = + block.thread_index().y * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + shared_vals[col * N_READS + i] = totals[i]; + } + block.sync(); + if (block.thread_index().y == 0) { + for (int i = 0; i < N_READS; i++) { + totals[i] = shared_vals[block.thread_index().x * N_READS + i]; + } + for (int j = 1; j < block.dim_threads().y; j++) { + col = j * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + totals[i] = op(shared_vals[col * N_READS + i], totals[i]); + } + } + } + } + + // Write result. + if (block.thread_index().y == 0) { + cub::StoreDirectBlocked( + column, + out + out_idx * args.reduction_stride, + totals, + args.reduction_stride); + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4> +__global__ void col_reduce_looped( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int n_warps = BN / N_READS; + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + // Read input to local. + int r = block.thread_rank() / n_warps; + int column = block.thread_rank() % n_warps; + int in_offset = grid.block_index().x * BN; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); + for (; r < args.non_col_reductions * args.reduction_size; r += BM) { + U vals[N_READS]; + cub::LoadDirectBlocked( + column, + make_cast_iterator(in + loop.location() + in_offset), + vals, + args.reduction_stride - in_offset, + ReduceInit::value()); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / n_warps; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[col + i] = totals[i]; + } + block.sync(); + col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[col + i], op); + } + + // Write result. 
+ if (warp.thread_rank() == 0) { + size_t out_offset = grid.block_index().x * BN; + cub::StoreDirectBlocked( + warp.meta_group_rank(), + out + out_idx * args.reduction_stride + out_offset, + totals, + args.reduction_stride - out_offset); + } +} + +} // namespace cu + +inline auto output_grid_for_col_reduce( + const array& out, + const cu::ColReduceArgs& args) { + auto out_shape = out.shape(); + auto out_strides = out.strides(); + while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { + out_shape.pop_back(); + out_strides.pop_back(); + } + return get_2d_grid_dims(out_shape, out_strides); +} + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + cu::ColReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = cuda_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = cu::ReduceResult::type; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + constexpr int N_READS = 4; + dim3 block_dims; + dim3 num_blocks = output_grid_for_col_reduce(out, args); + num_blocks.z = num_blocks.y; + num_blocks.y = num_blocks.x; + auto kernel = + cu::col_reduce_small; + size_t total = args.non_col_reductions * args.reduction_size; + if (total < 32) { + size_t stride_blocks = + cuda::ceil_div(args.reduction_stride, N_READS); + block_dims.x = std::min(stride_blocks, 32ul); + block_dims.y = std::min(total, 8ul); + num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x); + } else { + constexpr int BM = 32; + constexpr int BN = 32; + block_dims.x = BM * BN / N_READS; + num_blocks.x = cuda::ceil_div(args.reduction_stride, BN); + kernel = cu:: + col_reduce_looped; + } + kernel<<>>( + in.data(), out.data(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh new file mode 100644 index 000000000..0148022ab --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -0,0 +1,74 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +// Dispatch dynamic ndim to constexpr. +// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. +#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ + if (ndim == 1) { \ + constexpr uint32_t NDIM = 1; \ + __VA_ARGS__; \ + } else if (ndim == 2) { \ + constexpr uint32_t NDIM = 2; \ + __VA_ARGS__; \ + } else { \ + constexpr uint32_t NDIM = 5; \ + __VA_ARGS__; \ + } + +// Dispatch reduce ops to constexpr. +#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) 
\ + if (REDUCE == Reduce::ReduceType::And) { \ + using OP = cu::And; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Or) { \ + using OP = cu::Or; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Sum) { \ + using OP = cu::Sum; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Prod) { \ + using OP = cu::Prod; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Max) { \ + using OP = cu::Max; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Min) { \ + using OP = cu::Min; \ + __VA_ARGS__; \ + } else { \ + throw std::invalid_argument("Unknown reduce type."); \ + } + +void segmented_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh new file mode 100644 index 000000000..f06eb8541 --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -0,0 +1,144 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/kernels/utils.cuh" + +namespace mlx::core::cu { + +// Reduce ops. +struct And { + __device__ bool operator()(bool a, bool b) { + return a && b; + } +}; + +struct Or { + __device__ bool operator()(bool a, bool b) { + return a || b; + } +}; + +struct Sum { + template + __device__ T operator()(T a, T b) { + return a + b; + } +}; + +struct Prod { + template + __device__ T operator()(T a, T b) { + return a * b; + } +}; + +struct Min { + template + __device__ T operator()(T a, T b) { + return a < b ? a : b; + } +}; + +struct Max { + template + __device__ T operator()(T a, T b) { + return a > b ? a : b; + } +}; + +// Traits to get the result type of reduce op. +template +struct ReduceResult; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = T; +}; + +template +struct ReduceResult { + using type = T; +}; + +// Traits to get the init value of reduce op. 
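// (Illustrative aside, not part of the patch: these traits, parameterized on
// <Op, T>, are meant to give each op its identity element in the accumulation
// type chosen by ReduceResult, e.g.
//   ReduceInit<Sum, int8_t>::value()     -> int32_t{0}
//   ReduceInit<Prod, cuComplex>::value() -> cuComplex{1, 1}
//   ReduceInit<Max, float>::value()      -> Limits<float>::min(), i.e. -inf
//   ReduceInit<Min, float>::value()      -> Limits<float>::max(), i.e. +inf
// so a kernel can seed its accumulator with a value any input overrides.)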
+template +struct ReduceInit; + +template +struct ReduceInit { + static constexpr __host__ __device__ bool value() { + return true; + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ bool value() { + return false; + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ auto value() { + if constexpr (cuda::std::is_same_v) { + return T{0, 0}; + } else { + return typename ReduceResult::type{0}; + } + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ auto value() { + if constexpr (cuda::std::is_same_v) { + return T{1, 1}; + } else { + return typename ReduceResult::type{1}; + } + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::max(); + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::min(); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu new file mode 100644 index 000000000..3a5c4a591 --- /dev/null +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -0,0 +1,250 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct RowReduceArgs { + // The size of the row being reduced, i.e. the size of last dimension. + int row_size; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes excluding last dimension. + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). 
+ size_t non_row_reductions; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + row_size = plan.shape.back(); + + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size() - 1; + + non_row_reductions = 1; + for (int i = 0; i < reduce_ndim; i++) { + non_row_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void row_reduce_small( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + size_t out_idx = cg::this_grid().thread_rank(); + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value(); + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value()); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + out[out_idx] = total_val; +} + +template +__global__ void row_reduce_small_warp( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + size_t out_idx = grid.thread_rank() / WARP_SIZE; + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value(); + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = warp.thread_rank(); n < args.non_row_reductions; + n += WARP_SIZE) { + for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value()); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data()); + } + + total_val = cg::reduce(warp, total_val, op); + + if (warp.thread_rank() == 0) { + out[out_idx] = total_val; + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BLOCK_DIM_X, + int N_READS = 4> +__global__ void row_reduce_looped( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + size_t out_idx = grid.thread_rank() / BLOCK_DIM_X; + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value(); + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS); + r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM_X + block.thread_index().x, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + 
ReduceInit::value()); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp; + + total_val = BlockReduceT(temp).Reduce(total_val, op); + + if (block.thread_rank() == 0) { + out[out_idx] = total_val; + } +} + +} // namespace cu + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + cu::RowReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = cuda_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = cu::ReduceResult::type; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + constexpr size_t N_READS = 4; + dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block_dims, num_blocks; + auto kernel = + cu::row_reduce_small; + if (args.row_size <= 64) { + if ((args.non_row_reductions < 32 && args.row_size <= 8) || + (args.non_row_reductions <= 8)) { + block_dims.x = std::min(out_dims.x, 1024u); + num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x); + num_blocks.y = out_dims.y; + } else { + block_dims.x = WARP_SIZE; + num_blocks.y = out_dims.x; + num_blocks.z = out_dims.y; + kernel = + cu::row_reduce_small_warp; + } + } else { + size_t num_threads = cuda::ceil_div(args.row_size, N_READS); + num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE; + MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, { + num_blocks.y = out_dims.x; + num_blocks.z = out_dims.y; + block_dims.x = BLOCK_DIM_X; + kernel = cu::row_reduce_looped< + InType, + OutType, + OP, + NDIM, + BLOCK_DIM_X, + N_READS>; + }); + } + kernel<<>>( + in.data(), out.data(), out.size(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu new file mode 100644 index 000000000..563b056e4 --- /dev/null +++ b/mlx/backend/cuda/reduce/segmented_reduce.cu @@ -0,0 +1,84 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +template +void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); +} + +template +void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. 
+  CHECK_CUDA_ERROR(
+      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
+}
+
+struct MultiplyOp {
+  int factor;
+  __device__ int operator()(int i) {
+    return i * factor;
+  }
+};
+
+void segmented_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using InType = cuda_type_t<CTYPE>;
+        using OutType = cu::ReduceResult<OP, InType>::type;
+        auto in_iter = cu::make_cast_iterator<OutType>(
+            thrust::device_pointer_cast(in.data<InType>()));
+        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
+        auto init = cu::ReduceInit<OP, InType>::value();
+
+        if (plan.type == ContiguousAllReduce) {
+          cub_all_reduce(
+              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
+        } else if (plan.type == ContiguousReduce) {
+          auto offsets = thrust::make_transform_iterator(
+              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
+          cub_segmented_reduce(
+              encoder,
+              in_iter,
+              out_ptr,
+              out.size(),
+              offsets,
+              offsets + 1,
+              OP(),
+              init,
+              stream);
+        } else {
+          throw std::runtime_error("Unsupported plan in segmented_reduce.");
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
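
For reference, a standalone sketch (not part of the patch) of the CUB pattern that segmented_reduce.cu builds on: segment offsets are generated on the fly with a transform iterator, the same trick as MultiplyOp above, and cub::DeviceSegmentedReduce::Reduce is called twice, first to size its scratch buffer and then to run. The names here (RowOffset, d_in, d_out, the sizes) are illustrative only.

#include <cub/cub.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cstdio>
#include <vector>

// Generates segment offsets 0, row_size, 2 * row_size, ... on the fly,
// mirroring the MultiplyOp transform iterator used above.
struct RowOffset {
  int row_size;
  __host__ __device__ int operator()(int i) const {
    return i * row_size;
  }
};

int main() {
  const int num_rows = 4;
  const int row_size = 1024;
  std::vector<float> h_in(num_rows * row_size, 1.0f);

  float* d_in;
  float* d_out;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, num_rows * sizeof(float));
  cudaMemcpy(
      d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);

  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0), RowOffset{row_size});

  // CUB's two-phase calling convention: the first call with a null workspace
  // only reports the scratch size, the second call does the reduction. The
  // patch wraps the same dance in cub_all_reduce / cub_segmented_reduce and
  // keeps the scratch buffer alive as a command-encoder temporary.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedReduce::Reduce(
      d_temp, temp_bytes, d_in, d_out, num_rows, offsets, offsets + 1,
      cub::Sum(), 0.0f);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSegmentedReduce::Reduce(
      d_temp, temp_bytes, d_in, d_out, num_rows, offsets, offsets + 1,
      cub::Sum(), 0.0f);

  std::vector<float> h_out(num_rows);
  cudaMemcpy(
      h_out.data(), d_out, num_rows * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < num_rows; ++i) {
    std::printf("row %d sum = %g\n", i, h_out[i]); // expect 1024 per row
  }

  cudaFree(d_temp);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}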