CUDA backend: argreduce (#2270)

This commit is contained in:
Cheng 2025-06-12 05:26:17 +09:00 committed by GitHub
parent c9fa68664a
commit ccf78f566c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 250 additions and 1 deletions

View File

@ -6,6 +6,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu

View File

@ -0,0 +1,189 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
struct IndexValPair {
uint32_t index;
T val;
};
template <typename T>
struct ArgMin {
constexpr __device__ T init() {
return Limits<T>::max();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T>
struct ArgMax {
constexpr __device__ T init() {
return Limits<T>::min();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
const T* in,
uint32_t* out,
size_t size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides in_strides,
const __grid_constant__ Strides out_strides,
int32_t ndim,
int64_t axis_stride,
int32_t axis_size) {
auto block = cg::this_thread_block();
int64_t index = cg::this_grid().block_rank();
if (index >= size) {
return;
}
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().z;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
best = BlockReduceT(temp).Reduce(best, op);
if (block.thread_rank() == 0) {
out[out_idx] = best.index;
}
}
} // namespace cu
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
Strides in_strides = remove_index(in.strides(), axis_);
Strides out_strides = out.ndim() == in.ndim()
? remove_index(out.strides(), axis_)
: out.strides();
int64_t axis_stride = in.strides()[axis_];
int32_t axis_size = in.shape()[axis_];
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
using InType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{1, 1, BLOCK_DIM};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,
BLOCK_DIM,
N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = &cu::arg_reduce_general<
InType,
cu::ArgMin<InType>,
BLOCK_DIM,
N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu

View File

@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
}
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(BlockMaskedMM)
NO_GPU_MULTI(Compiled)
NO_GPU(Convolution)