mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
CUDA backend: reduce (#2269)
This commit is contained in:
parent
c35f4d089a
commit
c9fa68664a
@ -21,6 +21,10 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
|
@ -47,6 +47,31 @@ namespace mlx::core {
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
|
||||
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||
{ \
|
||||
uint32_t _num_threads = NUM_THREADS; \
|
||||
if (_num_threads <= WARP_SIZE) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
|
@ -9,6 +9,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
@ -19,6 +21,10 @@ namespace mlx::core::cu {
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
#define WARP_SIZE 32
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
#define MAX_NDIM 8
|
||||
@ -26,6 +32,94 @@ namespace mlx::core::cu {
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct Limits {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, __half> ||
|
||||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#else
|
||||
return -cuda::std::numeric_limits<float>::infinity();
|
||||
#endif
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#else
|
||||
return cuda::std::numeric_limits<float>::lowest();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr __host__ __device__ bool max() {
|
||||
return true;
|
||||
}
|
||||
static constexpr __host__ __device__ bool min() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<cuComplex> {
|
||||
static constexpr __host__ __device__ cuComplex max() {
|
||||
return {Limits<float>::max(), Limits<float>::max()};
|
||||
}
|
||||
static constexpr __host__ __device__ cuComplex min() {
|
||||
return {Limits<float>::min(), Limits<float>::min()};
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int DIM, bool General = true, typename OffsetT = size_t>
|
||||
struct LoopedElemToLoc {
|
||||
int dim;
|
||||
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index++;
|
||||
offset += OffsetT(strides[dim - 1]);
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index += n;
|
||||
offset += n * OffsetT(strides[dim - 1]);
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
if (extra >= shape[dim - 1]) {
|
||||
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||
extra = extra % shape[dim - 1];
|
||||
} else {
|
||||
inner_looper.next(shape, strides);
|
||||
}
|
||||
index = 0;
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, true, OffsetT> {
|
||||
int dim;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
index++;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
index += n;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset = index * OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
OffsetT offset{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int) {}
|
||||
|
||||
__device__ void next(const int*, const int64_t* strides) {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int*, const int64_t* strides) {
|
||||
offset += n * OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -91,7 +91,6 @@ NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
|
82
mlx/backend/cuda/reduce.cu
Normal file
82
mlx/backend/cuda/reduce.cu
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/fill.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Reduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
array in = inputs[0];
|
||||
|
||||
// Make sure no identity reductions trickle down here.
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
ReductionPlan plan = get_reduction_plan(in, axes_);
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
// recompute the plan.
|
||||
if (plan.type == GeneralReduce) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
|
||||
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce ||
|
||||
plan.type == GeneralStridedReduce) {
|
||||
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::runtime_error("No plan reached in reduce.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
@ -0,0 +1,278 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct ColReduceArgs {
|
||||
// The size of the contiguous column reduction.
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes (including last dimension).
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_col_reductions;
|
||||
|
||||
ColReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size();
|
||||
|
||||
non_col_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||
non_col_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
cub::StoreDirectBlocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||
kernel = cu::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Dispatch dynamic ndim to constexpr.
|
||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||
if (ndim == 1) { \
|
||||
constexpr uint32_t NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
} else if (ndim == 2) { \
|
||||
constexpr uint32_t NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t NDIM = 5; \
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Dispatch reduce ops to constexpr.
|
||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||
if (REDUCE == Reduce::ReduceType::And) { \
|
||||
using OP = cu::And; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||
using OP = cu::Or; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||
using OP = cu::Sum; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||
using OP = cu::Prod; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||
using OP = cu::Max; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||
using OP = cu::Min; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
}
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
} // namespace mlx::core
|
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceResult;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<And, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Or, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Sum, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Prod, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Min, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Max, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
// Traits to get the init value of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceInit;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<And, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Or, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Sum, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{0, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Sum, T>::type{0};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Min, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Max, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
@ -0,0 +1,250 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct RowReduceArgs {
|
||||
// The size of the row being reduced, i.e. the size of last dimension.
|
||||
int row_size;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes excluding last dimension.
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of rows we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_row_reductions;
|
||||
|
||||
RowReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
row_size = plan.shape.back();
|
||||
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size() - 1;
|
||||
|
||||
non_row_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim; i++) {
|
||||
non_row_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
size_t out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small_warp(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||
n += WARP_SIZE) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
total_val = cg::reduce(warp, total_val, op);
|
||||
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr size_t N_READS = 4;
|
||||
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims, num_blocks;
|
||||
auto kernel =
|
||||
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
if (args.row_size <= 64) {
|
||||
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
(args.non_row_reductions <= 8)) {
|
||||
block_dims.x = std::min(out_dims.x, 1024u);
|
||||
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
num_blocks.y = out_dims.y;
|
||||
} else {
|
||||
block_dims.x = WARP_SIZE;
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
kernel =
|
||||
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||
}
|
||||
} else {
|
||||
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
block_dims.x = BLOCK_DIM_X;
|
||||
kernel = cu::row_reduce_looped<
|
||||
InType,
|
||||
OutType,
|
||||
OP,
|
||||
NDIM,
|
||||
BLOCK_DIM_X,
|
||||
N_READS>;
|
||||
});
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user