mirror of https://github.com/ml-explore/mlx.git
CUDA backend: argreduce
mlx/backend/cuda/CMakeLists.txt
@@ -6,6 +6,7 @@
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
mlx/backend/cuda/arg_reduce.cu (new file, 198 lines)
@@ -0,0 +1,198 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename U>
struct IndexValPair {
  uint32_t index;
  U val;
};

template <typename U>
struct ArgMin {
  static constexpr U init = Limits<U>::max;

  __device__ IndexValPair<U> operator()(
      const IndexValPair<U>& best,
      const IndexValPair<U>& current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  __device__ IndexValPair<U>
  reduce_many(IndexValPair<U> best, U (&vals)[N], uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename U>
struct ArgMax {
  static constexpr U init = Limits<U>::min;

  __device__ IndexValPair<U> operator()(
      const IndexValPair<U>& best,
      const IndexValPair<U>& current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  __device__ IndexValPair<U>
  reduce_many(IndexValPair<U> best, U (&vals)[N], uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename U>
inline __device__ IndexValPair<U> warp_shuffle_down(
    const cg::thread_block_tile<WARP_SIZE>& g,
    const IndexValPair<U>& data,
    int delta) {
  return {g.shfl_down(data.index, delta), g.shfl_down(data.val, delta)};
}

template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
    const T* in,
    uint32_t* out,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides in_strides,
    const __grid_constant__ Strides out_strides,
    size_t ndim,
    int64_t axis_stride,
    size_t axis_size) {
  // Shapes and strides *do not* contain the reduction axis. The reduction size
  // and stride are provided in axis_stride and axis_size.
  //
  // Note: in shape == out shape with this convention.
  Op op;

  // Compute the input/output index. There is one beginning and one output for
  // the whole block.
  auto elem = cg::this_grid().block_rank();
  auto in_idx = elem_to_loc(elem, shape.data(), in_strides.data(), ndim);
  auto out_idx = elem_to_loc(elem, shape.data(), out_strides.data(), ndim);

  IndexValPair<T> best{0, Op::init};

  auto block = cg::this_thread_block();
  for (size_t r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    T vals[N_READS];
    auto index = r * BLOCK_DIM + block.thread_index().z;
    cub::LoadDirectBlocked(
        index,
        strided_iterator(in + in_idx, axis_stride),
        vals,
        axis_size,
        Op::init);
    best = op.reduce_many(best, vals, index * N_READS);
  }

  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  best = BlockReduceT(temp).Reduce(best, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = best.index;
  }
}

} // namespace cu

void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  auto& s = stream();

  // Prepare the shapes, strides and axis arguments.
  auto in_strides = in.strides();
  auto shape = in.shape();
  auto out_strides = out.strides();
  auto axis_stride = in_strides[axis_];
  size_t axis_size = shape[axis_];
  if (out_strides.size() == in_strides.size()) {
    out_strides.erase(out_strides.begin() + axis_);
  }
  in_strides.erase(in_strides.begin() + axis_);
  shape.erase(shape.begin() + axis_);
  size_t ndim = shape.size();

  // ArgReduce.
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
      using InType = cuda_type_t<CTYPE>;
      constexpr uint32_t N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
        dim3 block_dims{1, 1, BLOCK_DIM};
        auto kernel = &cu::arg_reduce_general<
            InType,
            cu::ArgMax<InType>,
            BLOCK_DIM,
            N_READS>;
        if (reduce_type_ == ArgReduce::ArgMin) {
          kernel = &cu::arg_reduce_general<
              InType,
              cu::ArgMin<InType>,
              BLOCK_DIM,
              N_READS>;
        }
        kernel<<<num_blocks, block_dims, 0, stream>>>(
            in.data<InType>(),
            out.data<uint32_t>(),
            const_param(shape),
            const_param(in_strides),
            const_param(out_strides),
            ndim,
            axis_stride,
            axis_size);
      });
    });
  });
}

} // namespace mlx::core
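Aside (not part of the diff): ArgMin and ArgMax break ties toward the lower index, so the reduction returns the first occurrence of the extreme value along the axis, and the block-wide result does not depend on the order in which cub::BlockReduce combines partial pairs. A minimal host-side C++ sketch of the same combine rule (the argmax_combine helper is illustrative, not part of MLX):

// Host-side sketch of the ArgMax tie-breaking rule used above.
#include <cstdint>
#include <iostream>
#include <limits>

struct Pair {
  uint32_t index;
  float val;
};

// Same combine rule as cu::ArgMax<float>::operator(), written for the host:
// prefer the larger value; on a tie, prefer the smaller index.
Pair argmax_combine(const Pair& best, const Pair& current) {
  if (best.val < current.val ||
      (best.val == current.val && best.index > current.index)) {
    return current;
  }
  return best;
}

int main() {
  Pair best{0, -std::numeric_limits<float>::infinity()};
  float data[] = {1.f, 5.f, 5.f, 2.f};
  for (uint32_t i = 0; i < 4; ++i) {
    best = argmax_combine(best, {i, data[i]});
  }
  std::cout << best.index << "\n"; // prints 1: first occurrence of the max
  return 0;
}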
mlx/backend/cuda/iterators/strided_iterator.cuh (new file, 61 lines)
@@ -0,0 +1,61 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>

namespace mlx::core::cu {

// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
    : public thrust::
          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
 public:
  using super_t =
      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;

  using reference = typename super_t::reference;
  using difference_type = typename super_t::difference_type;

  __host__ __device__ strided_iterator(Iterator it, Stride stride)
      : super_t(it), stride_(stride) {}

  __host__ __device__ Stride stride() const {
    return stride_;
  }

 private:
  friend class thrust::iterator_core_access;

  __host__ __device__ bool equal(const strided_iterator& other) const {
    return this->base() == other.base();
  }

  __host__ __device__ void advance(difference_type n) {
    this->base_reference() += n * stride_;
  }

  __host__ __device__ void increment() {
    this->base_reference() += stride_;
  }

  __host__ __device__ void decrement() {
    this->base_reference() -= stride_;
  }

  __host__ __device__ difference_type
  distance_to(const strided_iterator& other) const {
    const difference_type dist = other.base() - this->base();
    _CCCL_ASSERT(
        dist % stride() == 0,
        "Underlying iterator difference must be divisible by the stride");
    return dist / stride();
  }

  Stride stride_;
};

} // namespace mlx::core::cu
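Aside (not part of the diff): strided_iterator wraps a base iterator plus a stride, which is how the kernel above hands `in + in_idx` together with `axis_stride` to cub::LoadDirectBlocked so that consecutive iterator positions walk consecutive entries along the reduction axis. A small host-side sketch, assuming the header is reachable at the include path below and the file is compiled with nvcc:

// Sketch only: iterate column 1 of a 3x4 row-major matrix using stride 4.
#include <cstdio>

#include "mlx/backend/cuda/iterators/strided_iterator.cuh"

int main() {
  float m[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  mlx::core::cu::strided_iterator<const float*> it(m + 1, 4);
  for (int i = 0; i < 3; ++i, ++it) {
    std::printf("%g\n", *it); // prints 1, 5, 9
  }
  return 0;
}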
@@ -15,6 +15,33 @@

namespace mlx::core {

// Helper macros for dispatch macros (see below).
#define MLX_INTERNAL_IF_CASE(DIM, BLOCK_DIM, ...) \
  }                                               \
  else if (_num_threads <= DIM) {                 \
    constexpr uint32_t BLOCK_DIM = DIM;           \
    __VA_ARGS__;

#define MLX_INTERNAL_IF_CASE_DIMS(NUM_THREADS, BLOCK_DIM, ...) \
  {                                                            \
    uint32_t _num_threads = NUM_THREADS;                       \
    if (false) {                                               \
      MLX_INTERNAL_IF_CASE(32, BLOCK_DIM, __VA_ARGS__)         \
      MLX_INTERNAL_IF_CASE(64, BLOCK_DIM, __VA_ARGS__)         \
      MLX_INTERNAL_IF_CASE(128, BLOCK_DIM, __VA_ARGS__)        \
      MLX_INTERNAL_IF_CASE(256, BLOCK_DIM, __VA_ARGS__)        \
      MLX_INTERNAL_IF_CASE(512, BLOCK_DIM, __VA_ARGS__)        \
    } else {                                                   \
      constexpr uint32_t BLOCK_DIM = 1024;                     \
      __VA_ARGS__;                                             \
    }                                                          \
  }

// Some kernels use CUB which requires block_dim to be known at compile-time,
// use this macro to dispatch constexpr block_dim for the num_threads.
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
  MLX_INTERNAL_IF_CASE_DIMS(NUM_THREADS, BLOCK_DIM, __VA_ARGS__)

// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
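Aside (not part of the diff): the leading `}` in MLX_INTERNAL_IF_CASE closes the brace opened by the previous case, so MLX_SWITCH_BLOCK_DIM(n, BLOCK_DIM, body) expands to an if/else-if ladder in which each branch fixes BLOCK_DIM as a distinct constexpr that can be used as a template argument. A hand-expanded, host-only sketch of that ladder (launch_with is a made-up stand-in for the kernel-launch body):

// Sketch of what the MLX_SWITCH_BLOCK_DIM dispatch chain expands to.
#include <cstdint>
#include <cstdio>

template <uint32_t BLOCK_DIM>
void launch_with() {
  std::printf("dispatched BLOCK_DIM = %u\n", BLOCK_DIM);
}

void dispatch(uint32_t n) {
  // Each branch makes BLOCK_DIM a constexpr so it can parameterize a template.
  uint32_t _num_threads = n;
  if (false) {
  } else if (_num_threads <= 32) {
    constexpr uint32_t BLOCK_DIM = 32;
    launch_with<BLOCK_DIM>();
  } else if (_num_threads <= 64) {
    constexpr uint32_t BLOCK_DIM = 64;
    launch_with<BLOCK_DIM>();
  } else if (_num_threads <= 128) {
    constexpr uint32_t BLOCK_DIM = 128;
    launch_with<BLOCK_DIM>();
  } else if (_num_threads <= 256) {
    constexpr uint32_t BLOCK_DIM = 256;
    launch_with<BLOCK_DIM>();
  } else if (_num_threads <= 512) {
    constexpr uint32_t BLOCK_DIM = 512;
    launch_with<BLOCK_DIM>();
  } else {
    constexpr uint32_t BLOCK_DIM = 1024;
    launch_with<BLOCK_DIM>();
  }
}

int main() {
  dispatch(100); // prints "dispatched BLOCK_DIM = 128"
  return 0;
}

This is what lets arg_reduce.cu instantiate cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> even though the thread count is only known at run time.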
@@ -20,6 +20,10 @@ namespace mlx::core::cu {
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////

// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32

// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8
@@ -56,7 +56,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {

NO_GPU(AddMM)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BlockMaskedMM)
NO_GPU_MULTI(Compiled)