diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index a289f93b8..92d82feed 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -6,6 +6,7 @@ target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
new file mode 100644
index 000000000..ea211e7e8
--- /dev/null
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -0,0 +1,198 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_reduce.cuh>
+#include <nvtx3/nvtx3.hpp>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename U>
+struct IndexValPair {
+  uint32_t index;
+  U val;
+};
+
+template <typename U>
+struct ArgMin {
+  static constexpr U init = Limits<U>::max;
+
+  __device__ IndexValPair<U> operator()(
+      const IndexValPair<U>& best,
+      const IndexValPair<U>& current) {
+    if (best.val > current.val ||
+        (best.val == current.val && best.index > current.index)) {
+      return current;
+    } else {
+      return best;
+    }
+  }
+
+  template <int N>
+  __device__ IndexValPair<U>
+  reduce_many(IndexValPair<U> best, U (&vals)[N], uint32_t offset) {
+    for (int i = 0; i < N; i++) {
+      if (vals[i] < best.val) {
+        best.val = vals[i];
+        best.index = offset + i;
+      }
+    }
+    return best;
+  }
+};
+
+template <typename U>
+struct ArgMax {
+  static constexpr U init = Limits<U>::min;
+
+  __device__ IndexValPair<U> operator()(
+      const IndexValPair<U>& best,
+      const IndexValPair<U>& current) {
+    if (best.val < current.val ||
+        (best.val == current.val && best.index > current.index)) {
+      return current;
+    } else {
+      return best;
+    }
+  }
+
+  template <int N>
+  __device__ IndexValPair<U>
+  reduce_many(IndexValPair<U> best, U (&vals)[N], uint32_t offset) {
+    for (int i = 0; i < N; i++) {
+      if (vals[i] > best.val) {
+        best.val = vals[i];
+        best.index = offset + i;
+      }
+    }
+    return best;
+  }
+};
+
+template <typename U>
+inline __device__ IndexValPair<U> warp_shuffle_down(
+    const cg::thread_block_tile<WARP_SIZE>& g,
+    const IndexValPair<U>& data,
+    int delta) {
+  return {g.shfl_down(data.index, delta), g.shfl_down(data.val, delta)};
+}
+
+template <typename T, typename Op, int BLOCK_DIM, int N_READS>
+__global__ void arg_reduce_general(
+    const T* in,
+    uint32_t* out,
+    const __grid_constant__ Shape shape,
+    const __grid_constant__ Strides in_strides,
+    const __grid_constant__ Strides out_strides,
+    size_t ndim,
+    int64_t axis_stride,
+    size_t axis_size) {
+  // Shapes and strides *do not* contain the reduction axis. The reduction
+  // size and stride are provided in axis_size and axis_stride.
+  //
+  // Note: in shape == out shape with this convention.
+  Op op;
+
+  // Compute the input/output index. There is one input offset and one output
+  // for the whole block.
+  auto elem = cg::this_grid().block_rank();
+  auto in_idx = elem_to_loc(elem, shape.data(), in_strides.data(), ndim);
+  auto out_idx = elem_to_loc(elem, shape.data(), out_strides.data(), ndim);
+
+  IndexValPair<T> best{0, Op::init};
+
+  auto block = cg::this_thread_block();
+  for (size_t r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
+    T vals[N_READS];
+    auto index = r * BLOCK_DIM + block.thread_index().z;
+    cub::LoadDirectBlocked(
+        index,
+        strided_iterator(in + in_idx, axis_stride),
+        vals,
+        axis_size,
+        Op::init);
+    best = op.reduce_many(best, vals, index * N_READS);
+  }
+
+  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  best = BlockReduceT(temp).Reduce(best, op);
+
+  if (block.thread_rank() == 0) {
+    out[out_idx] = best.index;
+  }
+}
+
+} // namespace cu
+
+void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("ArgReduce::eval_gpu");
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+
+  // Prepare the shapes, strides and axis arguments.
+  auto in_strides = in.strides();
+  auto shape = in.shape();
+  auto out_strides = out.strides();
+  auto axis_stride = in_strides[axis_];
+  size_t axis_size = shape[axis_];
+  if (out_strides.size() == in_strides.size()) {
+    out_strides.erase(out_strides.begin() + axis_);
+  }
+  in_strides.erase(in_strides.begin() + axis_);
+  shape.erase(shape.begin() + axis_);
+  size_t ndim = shape.size();
+
+  // ArgReduce.
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
+      using InType = cuda_type_t<CTYPE>;
+      constexpr uint32_t N_READS = 4;
+      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
+        dim3 block_dims{1, 1, BLOCK_DIM};
+        auto kernel = &cu::arg_reduce_general<
+            InType,
+            cu::ArgMax<InType>,
+            BLOCK_DIM,
+            N_READS>;
+        if (reduce_type_ == ArgReduce::ArgMin) {
+          kernel = &cu::arg_reduce_general<
+              InType,
+              cu::ArgMin<InType>,
+              BLOCK_DIM,
+              N_READS>;
+        }
+        kernel<<<num_blocks, block_dims, 0, stream>>>(
+            in.data<InType>(),
+            out.data<uint32_t>(),
+            const_param(shape),
+            const_param(in_strides),
+            const_param(out_strides),
+            ndim,
+            axis_stride,
+            axis_size);
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh
new file mode 100644
index 000000000..20e683fdc
--- /dev/null
+++ b/mlx/backend/cuda/iterators/strided_iterator.cuh
@@ -0,0 +1,61 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <thrust/iterator/iterator_adaptor.h>
+#include <thrust/iterator/iterator_facade.h>
+#include <thrust/iterator/iterator_traits.h>
+
+namespace mlx::core::cu {
+
+// RandomAccessIterator for strided access to array entries.
+template <typename Iterator, typename Stride>
+class strided_iterator
+    : public thrust::
+          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
+ public:
+  using super_t =
+      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
+
+  using reference = typename super_t::reference;
+  using difference_type = typename super_t::difference_type;
+
+  __host__ __device__ strided_iterator(Iterator it, Stride stride)
+      : super_t(it), stride_(stride) {}
+
+  __host__ __device__ Stride stride() const {
+    return stride_;
+  }
+
+ private:
+  friend class thrust::iterator_core_access;
+
+  __host__ __device__ bool equal(const strided_iterator& other) const {
+    return this->base() == other.base();
+  }
+
+  __host__ __device__ void advance(difference_type n) {
+    this->base_reference() += n * stride_;
+  }
+
+  __host__ __device__ void increment() {
+    this->base_reference() += stride_;
+  }
+
+  __host__ __device__ void decrement() {
+    this->base_reference() -= stride_;
+  }
+
+  __host__ __device__ difference_type
+  distance_to(const strided_iterator& other) const {
+    const difference_type dist = other.base() - this->base();
+    _CCCL_ASSERT(
+        dist % stride() == 0,
+        "Underlying iterator difference must be divisible by the stride");
+    return dist / stride();
+  }
+
+  Stride stride_;
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 1d2150da7..6eb863ee1 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -15,6 +15,33 @@ namespace mlx::core {
 
+// Helper macros for the dispatch macro below.
+#define MLX_INTERNAL_IF_CASE(DIM, BLOCK_DIM, ...) \
+  }                                               \
+  else if (_num_threads <= DIM) {                 \
+    constexpr uint32_t BLOCK_DIM = DIM;           \
+    __VA_ARGS__;
+
+#define MLX_INTERNAL_IF_CASE_DIMS(NUM_THREADS, BLOCK_DIM, ...) \
+  {                                                            \
+    uint32_t _num_threads = NUM_THREADS;                       \
+    if (false) {                                               \
+      MLX_INTERNAL_IF_CASE(32, BLOCK_DIM, __VA_ARGS__)         \
+      MLX_INTERNAL_IF_CASE(64, BLOCK_DIM, __VA_ARGS__)         \
+      MLX_INTERNAL_IF_CASE(128, BLOCK_DIM, __VA_ARGS__)        \
+      MLX_INTERNAL_IF_CASE(256, BLOCK_DIM, __VA_ARGS__)        \
+      MLX_INTERNAL_IF_CASE(512, BLOCK_DIM, __VA_ARGS__)        \
+    } else {                                                   \
+      constexpr uint32_t BLOCK_DIM = 1024;                     \
+      __VA_ARGS__;                                             \
+    }                                                          \
+  }
+
+// Some kernels use CUB, which requires block_dim to be known at compile time;
+// use this macro to dispatch a constexpr BLOCK_DIM for the given NUM_THREADS.
+#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
+  MLX_INTERNAL_IF_CASE_DIMS(NUM_THREADS, BLOCK_DIM, __VA_ARGS__)
+
 // Maps CPU types to CUDA types.
 template <typename T>
 struct CTypeToCudaType {
diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh
index 6f77138e9..dad94ce55 100644
--- a/mlx/backend/cuda/kernels/utils.cuh
+++ b/mlx/backend/cuda/kernels/utils.cuh
@@ -20,6 +20,10 @@ namespace mlx::core::cu {
 // CUDA kernel utils
 ///////////////////////////////////////////////////////////////////////////////
 
+// All existing NVIDIA hardware has a fixed warp size of 32. Though a built-in
+// warpSize variable exists, using it would prevent compile-time optimizations.
+#define WARP_SIZE 32
+
 // To pass shape/strides to kernels via constant memory, their size must be
 // known at compile time.
 #define MAX_NDIM 8
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 12a1746a0..ddbc4ef22 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -56,7 +56,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 NO_GPU(AddMM)
 NO_GPU(ArgPartition)
-NO_GPU(ArgReduce)
 NO_GPU(ArgSort)
 NO_GPU(BlockMaskedMM)
 NO_GPU_MULTI(Compiled)
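
Note (reviewer sketch, not part of the diff): MLX_SWITCH_BLOCK_DIM above rounds the requested thread count up to the next supported power of two (32 through 1024) and exposes it as a constexpr BLOCK_DIM, which is what allows cub::BlockReduce to be instantiated inside the dispatched body. The standalone C++ sketch below mirrors that expansion using a hypothetical launch_with_block_dim helper; it only illustrates the shape of the dispatch under those assumptions and is not MLX code.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for a kernel launch that needs a compile-time block
// size (e.g. to instantiate cub::BlockReduce<T, BLOCK_DIM>).
template <uint32_t BLOCK_DIM>
void launch_with_block_dim(uint32_t num_threads) {
  std::printf("num_threads=%u -> BLOCK_DIM=%u\n", num_threads, BLOCK_DIM);
}

// Mirrors what MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM, ...) expands to:
// pick the smallest supported power-of-two block size that covers
// num_threads, falling back to 1024 for anything larger.
void dispatch_block_dim(uint32_t num_threads) {
  if (num_threads <= 32) {
    launch_with_block_dim<32>(num_threads);
  } else if (num_threads <= 64) {
    launch_with_block_dim<64>(num_threads);
  } else if (num_threads <= 128) {
    launch_with_block_dim<128>(num_threads);
  } else if (num_threads <= 256) {
    launch_with_block_dim<256>(num_threads);
  } else if (num_threads <= 512) {
    launch_with_block_dim<512>(num_threads);
  } else {
    launch_with_block_dim<1024>(num_threads);
  }
}

int main() {
  dispatch_block_dim(7);    // prints BLOCK_DIM=32
  dispatch_block_dim(300);  // prints BLOCK_DIM=512
  dispatch_block_dim(5000); // prints BLOCK_DIM=1024
  return 0;
}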