From a04436b6292e6554ea3cb020e1bad3365ef4d81d Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 17 Apr 2025 09:13:00 +0000 Subject: [PATCH] CUDA backend: reduce ops --- mlx/backend/cuda/CMakeLists.txt | 4 + mlx/backend/cuda/kernels/fp16_math.cuh | 53 ++++ mlx/backend/cuda/kernels/reduce_ops.cuh | 124 +++++++++ mlx/backend/cuda/kernels/utils.cuh | 155 +++++++++++ mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/reduce.cu | 82 ++++++ mlx/backend/cuda/reduce/col_reduce.cu | 279 ++++++++++++++++++++ mlx/backend/cuda/reduce/reduce.cuh | 74 ++++++ mlx/backend/cuda/reduce/row_reduce.cu | 251 ++++++++++++++++++ mlx/backend/cuda/reduce/segmented_reduce.cu | 86 ++++++ mlx/backend/cuda/utils.cpp | 23 ++ mlx/backend/cuda/utils.h | 5 + 12 files changed, 1136 insertions(+), 1 deletion(-) create mode 100644 mlx/backend/cuda/kernels/reduce_ops.cuh create mode 100644 mlx/backend/cuda/reduce.cu create mode 100644 mlx/backend/cuda/reduce/col_reduce.cu create mode 100644 mlx/backend/cuda/reduce/reduce.cuh create mode 100644 mlx/backend/cuda/reduce/row_reduce.cu create mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index a077baed0..a289f93b8 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -13,6 +13,10 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index f6fa17bb9..433b32d29 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -2,6 +2,7 @@ #pragma once +#include #include #include #include @@ -9,6 +10,58 @@ namespace mlx::core::cu { +/////////////////////////////////////////////////////////////////////////////// +// Constant values for half types. +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_DEFINE_CONSTEXPR_VALUE(NAME, HALF_VALUE, BF16_VALUE, ...) 
\ + template \ + constexpr __host__ __device__ T NAME() { \ + if constexpr (cuda::std::is_same_v) { \ + uint16_t value = HALF_VALUE; \ + return __builtin_bit_cast(__half, value); \ + } else if constexpr (cuda::std::is_same_v) { \ + uint16_t value = BF16_VALUE; \ + return __builtin_bit_cast(__nv_bfloat16, value); \ + } else { \ + __VA_ARGS__ \ + } \ + } + +MLX_DEFINE_CONSTEXPR_VALUE(zero_value, 0x0000, 0x0000, { + if constexpr (cuda::std::is_same_v) { + return cuComplex{0, 0}; + } else { + return 0; + } +}) + +MLX_DEFINE_CONSTEXPR_VALUE(one_value, 0x3C00, 0x3F80, { + if constexpr (cuda::std::is_same_v) { + return cuComplex{1, 1}; + } else { + return 1; + } +}) + +MLX_DEFINE_CONSTEXPR_VALUE(infinite_value, 0x7C00, 0x7F80, { + return cuda::std::numeric_limits::infinity(); +}) + +MLX_DEFINE_CONSTEXPR_VALUE(negative_infinite_value, 0xFC00, 0xFF80, { + return -cuda::std::numeric_limits::infinity(); +}) + +MLX_DEFINE_CONSTEXPR_VALUE(max_value, 0x7BFF, 0x7F7F, { + return cuda::std::numeric_limits::max(); +}) + +MLX_DEFINE_CONSTEXPR_VALUE(lowest_value, 0xFBFF, 0xFF7F, { + return cuda::std::numeric_limits::lowest(); +}) + +#undef MLX_DEFINE_CONSTEXPR_VALUE + /////////////////////////////////////////////////////////////////////////////// // Unary ops for half types. /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/kernels/reduce_ops.cuh b/mlx/backend/cuda/kernels/reduce_ops.cuh new file mode 100644 index 000000000..8269f9a8e --- /dev/null +++ b/mlx/backend/cuda/kernels/reduce_ops.cuh @@ -0,0 +1,124 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/kernels/utils.cuh" + +namespace mlx::core::cu { + +// Reduce ops. +struct And { + __device__ bool operator()(bool a, bool b) { + return a && b; + } +}; + +struct Or { + __device__ bool operator()(bool a, bool b) { + return a || b; + } +}; + +struct Sum { + template + __device__ T operator()(T a, T b) { + return a + b; + } +}; + +struct Prod { + template + __device__ T operator()(T a, T b) { + return a * b; + } +}; + +struct Min { + template + __device__ T operator()(T a, T b) { + return a < b ? a : b; + } +}; + +struct Max { + template + __device__ T operator()(T a, T b) { + return a > b ? a : b; + } +}; + +// Traits to get the result type of reduce op. +template +struct ReduceResult; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = T; +}; + +template +struct ReduceResult { + using type = T; +}; + +// Traits to get the init value of reduce op. 
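+// Each op is paired with the identity element of its reduction, so partial
+// accumulators can be initialized before any input is read: And starts at
+// true, Or at false, Sum at zero, Prod at one, Min at Limits<T>::max and
+// Max at Limits<T>::min. The reduce kernels rely on this whenever a thread
+// or block loads past the end of its data (out-of-range lanes are filled
+// with this value).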
+template +struct ReduceInit; + +template +struct ReduceInit { + static constexpr bool value = true; +}; + +template +struct ReduceInit { + static constexpr bool value = false; +}; + +template +struct ReduceInit { + static constexpr typename ReduceResult::type value = zero_value(); +}; + +template +struct ReduceInit { + static constexpr typename ReduceResult::type value = one_value(); +}; + +template +struct ReduceInit { + static constexpr T value = Limits::max; +}; + +template +struct ReduceInit { + static constexpr T value = Limits::min; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh index 4d69b7356..6f77138e9 100644 --- a/mlx/backend/cuda/kernels/utils.cuh +++ b/mlx/backend/cuda/kernels/utils.cuh @@ -8,6 +8,8 @@ #pragma once +#include "mlx/backend/cuda/kernels/fp16_math.cuh" + #include #include #include @@ -25,6 +27,55 @@ namespace mlx::core::cu { using Shape = cuda::std::array; using Strides = cuda::std::array; +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static constexpr U max = cuda::std::numeric_limits::max(); + static constexpr U min = cuda::std::numeric_limits::min(); + static constexpr U finite_max = cuda::std::numeric_limits::max(); + static constexpr U finite_min = cuda::std::numeric_limits::min(); +}; + +template <> +struct Limits { + static constexpr bool max = true; + static constexpr bool min = false; +}; + +template <> +struct Limits { + static constexpr cuComplex max = { + cuda::std::numeric_limits::infinity(), + cuda::std::numeric_limits::infinity()}; + static constexpr cuComplex min = { + -cuda::std::numeric_limits::infinity(), + -cuda::std::numeric_limits::infinity()}; +}; + +// Like MLX_FORALL_FLOAT_TYPES but use CUDA types. +#define MLX_FORALL_CUDA_FLOAT_TYPES(_) \ + _(float, float32) \ + _(double, float64) \ + _(__half, float16) \ + _(__nv_bfloat16, bfloat16) + +// Some CCCL/CUDA combinations do not provide constexpr limits for half types. 
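+// The bit-pattern constants from fp16_math.cuh fill that role instead: for
+// __half, max is the +inf encoding 0x7C00 while finite_max is 0x7BFF
+// (65504), mirroring what cuda::std::numeric_limits provides for float and
+// double through the primary template above.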
+#define SPECIALIZE_FloatLimits(CPP_TYPE, DTYPE) \ + template <> \ + struct Limits { \ + static constexpr CPP_TYPE max = infinite_value(); \ + static constexpr CPP_TYPE min = negative_infinite_value(); \ + static constexpr CPP_TYPE finite_max = max_value(); \ + static constexpr CPP_TYPE finite_min = lowest_value(); \ + }; + +MLX_FORALL_CUDA_FLOAT_TYPES(SPECIALIZE_FloatLimits) + +#undef SPECIALIZE_FloatLimits + /////////////////////////////////////////////////////////////////////////////// // Indexing utils /////////////////////////////////////////////////////////////////////////////// @@ -40,4 +91,108 @@ elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { return loc; } +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + __device__ void next(const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, true, OffsetT> { + int dim; + OffsetT offset{0}; + uint index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim) {} + + __device__ void next(const int* shape, const int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, false, OffsetT> { + OffsetT offset{0}; + + __device__ LoopedElemToLoc(int) {} + + __device__ void next(const int*, const int64_t* strides) { + offset += OffsetT(strides[0]); + } + + __device__ void next(int n, const int*, const int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + __device__ OffsetT location() { + return offset; + } +}; + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 31f393bfa..12a1746a0 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -78,7 +78,6 @@ NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) -NO_GPU(Reduce) NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu new file mode 100644 index 000000000..d32180d77 --- /dev/null +++ b/mlx/backend/cuda/reduce.cu @@ -0,0 +1,82 @@ +// 
Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Reduce::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Reduce::eval_gpu"); + assert(inputs.size() == 1); + array in = inputs[0]; + + // Make sure no identity reductions trickle down here. + assert(!axes_.empty()); + assert(out.size() != in.size()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Fill out with init value. + if (in.size() == 0) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, { + using InType = cuda_type_t; + using OutType = cu::ReduceResult::type; + thrust::fill_n( + cu::thrust_policy(stream), + thrust::device_pointer_cast(out.data()), + out.data_size(), + cu::ReduceInit::value); + }); + }); + }); + return; + } + + // Reduce. + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. + if (plan.type == GeneralReduce) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + encoder.add_temporary(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + if ((plan.type == ContiguousAllReduce) || + (plan.type == ContiguousReduce && plan.shape.size() == 1)) { + segmented_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + col_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + throw std::runtime_error("No plan reached in reduce."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu new file mode 100644 index 000000000..6653d1343 --- /dev/null +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -0,0 +1,279 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/cast_iterator.cuh" +#include "mlx/backend/cuda/kernels/reduce_ops.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). 
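+  // More precisely, prod(reduce_shape) without its trailing reduction_size
+  // dimension; the kernels below loop over
+  // non_col_reductions * reduction_size inputs per output column.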
+ size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + int column = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + if (column * N_READS >= args.reduction_stride) { + return; + } + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value; + } + + // Read input to local. + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next( + block.thread_index().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + for (size_t r = block.thread_index().y; + r < args.non_col_reductions * args.reduction_size; + r += block.dim_threads().y) { + U vals[N_READS]; + cub::LoadDirectBlocked( + column, + make_cast_iterator(in + loop.location()), + vals, + args.reduction_stride, + ReduceInit::value); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next( + block.dim_threads().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + } + + // Do block reduce when each column has more than 1 element to reduce. + if (block.dim_threads().y > 1) { + __shared__ U shared_vals[32 * 8 * N_READS]; + size_t col = + block.thread_index().y * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + shared_vals[col * N_READS + i] = totals[i]; + } + block.sync(); + if (block.thread_index().y == 0) { + for (int i = 0; i < N_READS; i++) { + totals[i] = shared_vals[block.thread_index().x * N_READS + i]; + } + for (int j = 1; j < block.dim_threads().y; j++) { + col = j * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + totals[i] = op(shared_vals[col * N_READS + i], totals[i]); + } + } + } + } + + // Write result. 
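+  // Only the y == 0 row of threads holds the combined totals at this point;
+  // each of its threads stores N_READS values, with reduction_stride passed
+  // to cub::StoreDirectBlocked so a partial vector at the edge is clamped.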
+ if (block.thread_index().y == 0) { + cub::StoreDirectBlocked( + column, + out + out_idx * args.reduction_stride, + totals, + args.reduction_stride); + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4> +__global__ void col_reduce_looped( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int n_warps = BN / N_READS; + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value; + } + + // Read input to local. + int r = block.thread_rank() / n_warps; + int column = block.thread_rank() % n_warps; + int in_offset = grid.block_index().x * BN; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); + for (; r < args.non_col_reductions * args.reduction_size; r += BM) { + U vals[N_READS]; + cub::LoadDirectBlocked( + column, + make_cast_iterator(in + loop.location() + in_offset), + vals, + args.reduction_stride - in_offset, + ReduceInit::value); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / n_warps; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[col + i] = totals[i]; + } + block.sync(); + col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[col + i], op); + } + + // Write result. 
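+  // cg::reduce leaves every lane of the warp with the reduced n_outputs
+  // totals; lane 0 alone writes them, and out_offset selects this block's
+  // BN-wide tile within the output row.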
+ if (warp.thread_rank() == 0) { + size_t out_offset = grid.block_index().x * BN; + cub::StoreDirectBlocked( + warp.meta_group_rank(), + out + out_idx * args.reduction_stride + out_offset, + totals, + args.reduction_stride - out_offset); + } +} + +} // namespace cu + +inline auto output_grid_for_col_reduce( + const array& out, + const cu::ColReduceArgs& args) { + auto out_shape = out.shape(); + auto out_strides = out.strides(); + while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { + out_shape.pop_back(); + out_strides.pop_back(); + } + return get_2d_grid_dims(out_shape, out_strides); +} + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + cu::ColReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = cuda_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = cu::ReduceResult::type; + MLX_SWITCH_NDIM(args.reduce_ndim, NDIM, { + constexpr int N_READS = 4; + dim3 block_dims; + dim3 num_blocks = output_grid_for_col_reduce(out, args); + num_blocks.z = num_blocks.y; + num_blocks.y = num_blocks.x; + auto kernel = + cu::col_reduce_small; + size_t total = args.non_col_reductions * args.reduction_size; + if (total < 32) { + size_t stride_blocks = + cuda::ceil_div(args.reduction_stride, N_READS); + block_dims.x = std::min(stride_blocks, 32ul); + block_dims.y = std::min(total, 8ul); + num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x); + } else { + constexpr int BM = 32; + constexpr int BN = 32; + block_dims.x = BM * BN / N_READS; + num_blocks.x = cuda::ceil_div(args.reduction_stride, BN); + kernel = cu:: + col_reduce_looped; + } + kernel<<>>( + in.data(), out.data(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh new file mode 100644 index 000000000..b9071c594 --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -0,0 +1,74 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" +#include "mlx/backend/cuda/kernels/reduce_ops.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +// Dispatch dynamic ndim to constexpr. +#define MORE_THAN_TWO 5 +#define MLX_SWITCH_NDIM(ndim, NDIM, ...) \ + if (ndim == 1) { \ + constexpr uint32_t NDIM = 1; \ + __VA_ARGS__; \ + } else if (ndim == 2) { \ + constexpr uint32_t NDIM = 2; \ + __VA_ARGS__; \ + } else { \ + constexpr uint32_t NDIM = MORE_THAN_TWO; \ + __VA_ARGS__; \ + } + +// Dispatch reduce ops to constexpr. +#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) 
\ + if (REDUCE == Reduce::ReduceType::And) { \ + using OP = cu::And; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Or) { \ + using OP = cu::Or; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Sum) { \ + using OP = cu::Sum; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Prod) { \ + using OP = cu::Prod; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Max) { \ + using OP = cu::Max; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Min) { \ + using OP = cu::Min; \ + __VA_ARGS__; \ + } else { \ + throw std::invalid_argument("Unknown reduce type."); \ + } + +void segmented_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu new file mode 100644 index 000000000..7a2d8e91b --- /dev/null +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -0,0 +1,251 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/cast_iterator.cuh" +#include "mlx/backend/cuda/kernels/reduce_ops.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct RowReduceArgs { + // The size of the row being reduced, i.e. the size of last dimension. + int row_size; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes excluding last dimension. + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). 
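+  // reduce_ndim excludes the trailing row dimension, so this is the number
+  // of rows of length row_size folded into each output element.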
+ size_t non_row_reductions; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + row_size = plan.shape.back(); + + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size() - 1; + + non_row_reductions = 1; + for (int i = 0; i < reduce_ndim; i++) { + non_row_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void row_reduce_small( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + size_t out_idx = cg::this_grid().thread_rank(); + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + out[out_idx] = total_val; +} + +template +__global__ void row_reduce_small_warp( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + size_t out_idx = grid.thread_rank() / WARP_SIZE; + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = warp.thread_rank(); n < args.non_row_reductions; + n += WARP_SIZE) { + for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data()); + } + + total_val = cg::reduce(warp, total_val, op); + + if (warp.thread_rank() == 0) { + out[out_idx] = total_val; + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BLOCK_DIM_X, + int N_READS = 4> +__global__ void row_reduce_looped( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + size_t out_idx = grid.thread_rank() / BLOCK_DIM_X; + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS); + r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM_X + block.thread_index().x, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value); + 
total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp; + + total_val = BlockReduceT(temp).Reduce(total_val, op); + + if (block.thread_rank() == 0) { + out[out_idx] = total_val; + } +} + +} // namespace cu + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + cu::RowReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = cuda_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = cu::ReduceResult::type; + MLX_SWITCH_NDIM(args.reduce_ndim, NDIM, { + constexpr size_t N_READS = 4; + dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block_dims, num_blocks; + auto kernel = + cu::row_reduce_small; + if (args.row_size <= 64) { + if ((args.non_row_reductions < 32 && args.row_size <= 8) || + (args.non_row_reductions <= 8)) { + block_dims.x = std::min(out_dims.x, 1024u); + num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x); + num_blocks.y = out_dims.y; + } else { + block_dims.x = WARP_SIZE; + num_blocks.y = out_dims.x; + num_blocks.z = out_dims.y; + kernel = + cu::row_reduce_small_warp; + } + } else { + size_t num_threads = cuda::ceil_div(args.row_size, N_READS); + num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE; + MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, { + num_blocks.y = out_dims.x; + num_blocks.z = out_dims.y; + block_dims.x = BLOCK_DIM_X; + kernel = cu::row_reduce_looped< + InType, + OutType, + OP, + NDIM, + BLOCK_DIM_X, + N_READS>; + }); + } + kernel<<>>( + in.data(), out.data(), out.size(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu new file mode 100644 index 000000000..07aac4748 --- /dev/null +++ b/mlx/backend/cuda/reduce/segmented_reduce.cu @@ -0,0 +1,86 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/cast_iterator.cuh" +#include "mlx/backend/cuda/iterators/general_iterator.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +template +void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); +} + +template +void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. 
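+  // CUB convention: the first Reduce call above, given a null workspace,
+  // only computed the required temp-storage size; this second call performs
+  // the actual segmented reduction using the scratch array allocated above.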
+ CHECK_CUDA_ERROR( + cub::DeviceSegmentedReduce::Reduce(temp.data(), size, args...)); +} + +struct MultiplyOp { + int factor; + + __device__ int operator()(int i) { + return i * factor; + } +}; + +void segmented_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using InType = cuda_type_t; + using OutType = cu::ReduceResult::type; + auto in_iter = cu::make_cast_iterator( + thrust::device_pointer_cast(in.data())); + auto out_ptr = thrust::device_pointer_cast(out.data()); + auto init = cu::ReduceInit::value; + + if (plan.type == ContiguousAllReduce) { + cub_all_reduce( + encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream); + } else if (plan.type == ContiguousReduce) { + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()}); + cub_segmented_reduce( + encoder, + in_iter, + out_ptr, + out.size(), + offsets, + offsets + 1, + OP(), + init, + stream); + } else { + throw std::runtime_error("Unsupported plan in segmented_reduce."); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 2a11a518e..093a88a90 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -23,4 +23,27 @@ void check_cuda_error(const char* name, cudaError_t err) { } } +// TODO: The implementation is identical to meta/utils.cpp +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return dim3{static_cast(grid_x), static_cast(grid_y), 1}; +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 6eaec8984..3edf61076 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -4,6 +4,8 @@ #pragma once +#include "mlx/array.h" + #include namespace mlx::core { @@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +// Computes a 2D grid where each element is < UINT_MAX. +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides); + } // namespace mlx::core