From ccf78f566ca8eae9de82b4bf64f043078db2a0a8 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Thu, 12 Jun 2025 05:26:17 +0900
Subject: [PATCH] CUDA backend: argreduce (#2270)

---
 mlx/backend/cuda/CMakeLists.txt               |   1 +
 mlx/backend/cuda/arg_reduce.cu                | 189 ++++++++++++++++++
 .../cuda/iterators/strided_iterator.cuh       |  60 ++++++
 mlx/backend/cuda/primitives.cu                |   1 -
 4 files changed, 250 insertions(+), 1 deletion(-)
 create mode 100644 mlx/backend/cuda/arg_reduce.cu
 create mode 100644 mlx/backend/cuda/iterators/strided_iterator.cuh

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index c053b4428..ab0d5fe7c 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -6,6 +6,7 @@ target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
new file mode 100644
index 000000000..7dbd91e46
--- /dev/null
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -0,0 +1,189 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_reduce.cuh>
+#include <nvtx3/nvtx3.hpp>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T>
+struct IndexValPair {
+  uint32_t index;
+  T val;
+};
+
+template <typename T>
+struct ArgMin {
+  constexpr __device__ T init() {
+    return Limits<T>::max();
+  }
+
+  __device__ IndexValPair<T> operator()(
+      const IndexValPair<T>& best,
+      const IndexValPair<T>& current) {
+    if (best.val > current.val ||
+        (best.val == current.val && best.index > current.index)) {
+      return current;
+    } else {
+      return best;
+    }
+  }
+
+  template <int N>
+  __device__ IndexValPair<T>
+  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
+    for (int i = 0; i < N; i++) {
+      if (vals[i] < best.val) {
+        best.val = vals[i];
+        best.index = offset + i;
+      }
+    }
+    return best;
+  }
+};
+
+template <typename T>
+struct ArgMax {
+  constexpr __device__ T init() {
+    return Limits<T>::min();
+  }
+
+  __device__ IndexValPair<T> operator()(
+      const IndexValPair<T>& best,
+      const IndexValPair<T>& current) {
+    if (best.val < current.val ||
+        (best.val == current.val && best.index > current.index)) {
+      return current;
+    } else {
+      return best;
+    }
+  }
+
+  template <int N>
+  __device__ IndexValPair<T>
+  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
+    for (int i = 0; i < N; i++) {
+      if (vals[i] > best.val) {
+        best.val = vals[i];
+        best.index = offset + i;
+      }
+    }
+    return best;
+  }
+};
+
+template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
+__global__ void arg_reduce_general(
+    const T* in,
+    uint32_t* out,
+    size_t size,
+    const __grid_constant__ Shape shape,
+    const __grid_constant__ Strides in_strides,
+    const __grid_constant__ Strides out_strides,
+    int32_t ndim,
+    int64_t axis_stride,
+    int32_t axis_size) {
+  auto block = cg::this_thread_block();
+
+  int64_t index = cg::this_grid().block_rank();
+  if (index >= size) {
+    return;
+  }
+
+  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
+  int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
+
+  Op op;
+  T init = op.init();
+  IndexValPair<T> best{0, init};
+
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
+    T vals[N_READS];
+    auto tid = r * BLOCK_DIM + block.thread_index().z;
+    cub::LoadDirectBlocked(
+        tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
+    best = op.reduce_many(best, vals, tid * N_READS);
+  }
+
+  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  best = BlockReduceT(temp).Reduce(best, op);
+
+  if (block.thread_rank() == 0) {
+    out[out_idx] = best.index;
+  }
+}
+
+} // namespace cu
+
+void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("ArgReduce::eval_gpu");
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+
+  // Prepare the shapes, strides and axis arguments.
+  Shape shape = remove_index(in.shape(), axis_);
+  Strides in_strides = remove_index(in.strides(), axis_);
+  Strides out_strides = out.ndim() == in.ndim()
+      ? remove_index(out.strides(), axis_)
+      : out.strides();
+  int64_t axis_stride = in.strides()[axis_];
+  int32_t axis_size = in.shape()[axis_];
+  int32_t ndim = shape.size();
+
+  // ArgReduce.
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
+      using InType = cuda_type_t<CTYPE>;
+      constexpr uint32_t N_READS = 4;
+      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
+        dim3 block_dims{1, 1, BLOCK_DIM};
+        auto kernel = &cu::arg_reduce_general<
+            InType,
+            cu::ArgMax<InType>,
+            BLOCK_DIM,
+            N_READS>;
+        if (reduce_type_ == ArgReduce::ArgMin) {
+          kernel = &cu::arg_reduce_general<
+              InType,
+              cu::ArgMin<InType>,
+              BLOCK_DIM,
+              N_READS>;
+        }
+        kernel<<<num_blocks, block_dims, 0, stream>>>(
+            in.data<InType>(),
+            out.data<uint32_t>(),
+            out.size(),
+            const_param(shape),
+            const_param(in_strides),
+            const_param(out_strides),
+            ndim,
+            axis_stride,
+            axis_size);
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh
new file mode 100644
index 000000000..3ef8d66bd
--- /dev/null
+++ b/mlx/backend/cuda/iterators/strided_iterator.cuh
@@ -0,0 +1,60 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <thrust/iterator/iterator_adaptor.h>
+#include <thrust/iterator/iterator_facade.h>
+
+namespace mlx::core::cu {
+
+// RandomAccessIterator for strided access to array entries.
+template <typename Iterator, typename Stride = int64_t>
+class strided_iterator
+    : public thrust::
+          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
+ public:
+  using super_t =
+      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
+
+  using reference = typename super_t::reference;
+  using difference_type = typename super_t::difference_type;
+
+  __host__ __device__ strided_iterator(Iterator it, Stride stride)
+      : super_t(it), stride_(stride) {}
+
+  __host__ __device__ Stride stride() const {
+    return stride_;
+  }
+
+ private:
+  friend class thrust::iterator_core_access;
+
+  __host__ __device__ bool equal(const strided_iterator& other) const {
+    return this->base() == other.base();
+  }
+
+  __host__ __device__ void advance(difference_type n) {
+    this->base_reference() += n * stride_;
+  }
+
+  __host__ __device__ void increment() {
+    this->base_reference() += stride_;
+  }
+
+  __host__ __device__ void decrement() {
+    this->base_reference() -= stride_;
+  }
+
+  __host__ __device__ difference_type
+  distance_to(const strided_iterator& other) const {
+    const difference_type dist = other.base() - this->base();
+    _CCCL_ASSERT(
+        dist % stride() == 0,
+        "Underlying iterator difference must be divisible by the stride");
+    return dist / stride();
+  }
+
+  Stride stride_;
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 1b273e959..5cf19711c 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 }
 
 NO_GPU(ArgPartition)
-NO_GPU(ArgReduce)
 NO_GPU(BlockMaskedMM)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Convolution)
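
Note on the new strided_iterator (not part of the patch): the adaptor only advances the wrapped iterator by a fixed stride per step, which is what lets cub::LoadDirectBlocked in arg_reduce_general walk the reduction axis of a non-contiguous array as if it were a contiguous run. A minimal host-side sketch of that behavior, assuming it is compiled with nvcc against the MLX and CCCL include paths:

// Illustrative only: read every 4th element of a 3x4 row-major matrix,
// i.e. walk its first column, using the strided_iterator added above.
#include <cstdio>

#include "mlx/backend/cuda/iterators/strided_iterator.cuh"

int main() {
  float data[12];
  for (int i = 0; i < 12; ++i) {
    data[i] = static_cast<float>(i);
  }
  // Wrap the raw pointer with a stride of 4 (one row of the matrix).
  auto it = mlx::core::cu::strided_iterator(data, 4);
  for (int i = 0; i < 3; ++i, ++it) {
    std::printf("%g\n", static_cast<float>(*it)); // prints 0, 4, 8
  }
  return 0;
}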
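With NO_GPU(ArgReduce) removed, argmax/argmin now dispatch to the new CUDA kernel instead of throwing. A hypothetical smoke test, not part of the patch, assuming MLX is built with CUDA support and the GPU is the default device:

// Illustrative only: exercise ArgReduce::eval_gpu end to end.
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // 3x4 matrix with values 0..11; argmax along the last axis.
  auto a = mx::reshape(mx::arange(12.0), {3, 4});
  auto idx = mx::argmax(a, 1);
  std::cout << idx << std::endl; // expected: [3, 3, 3] as uint32
  return 0;
}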