Simple row reduce

This commit is contained in:
Angelos Katharopoulos 2025-06-18 23:17:16 -07:00
parent b70a964cde
commit 4d2b682a13
5 changed files with 275 additions and 129 deletions

View File

@ -21,12 +21,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(!axes_.empty()); assert(!axes_.empty());
assert(out.size() != in.size()); assert(out.size() != in.size());
// out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
// encoder.set_input_array(in);
// encoder.set_output_array(out);
if (in.size() == 0) { if (in.size() == 0) {
throw std::runtime_error("Should never reach here."); throw std::runtime_error("Should never reach here.");
@ -50,11 +46,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce(encoder, in, out, reduce_type_, axes_, plan); row_reduce(encoder, in, out, reduce_type_, axes_, plan);
return; return;

View File

@ -13,26 +13,6 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
namespace {
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
return x.x != 0 && x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
} // namespace
template <typename T, typename U, typename ReduceOp, int N = 4> template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
@ -54,8 +34,8 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) { for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals); cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int i = 0; i < N; i++) { for (int j = 0; j < N; j++) {
accs[i] = op(accs[i], __cast<U, T>(vals[i])); accs[j] = op(accs[j], __cast<U, T>(vals[j]));
} }
} }
@ -74,15 +54,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
} }
accs[0] = cg::reduce(warp, accs[0], op); accs[0] = cg::reduce(warp, accs[0], op);
__shared__ U shared_accumulators[32]; if (warp.meta_group_size() > 1) {
if (warp.thread_rank() == 0) { __shared__ U shared_accumulators[32];
shared_accumulators[warp.meta_group_rank()] = accs[0]; if (warp.thread_rank() == 0) {
shared_accumulators[warp.meta_group_rank()] = accs[0];
}
block.sync();
accs[0] = (warp.thread_rank() < warp.meta_group_size())
? shared_accumulators[warp.thread_rank()]
: init;
accs[0] = cg::reduce(warp, accs[0], op);
} }
block.sync();
accs[0] = (warp.thread_rank() < warp.meta_group_size())
? shared_accumulators[warp.thread_rank()]
: init;
accs[0] = cg::reduce(warp, accs[0], op);
if (block.thread_rank() == 0) { if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0]; out[grid.block_rank()] = accs[0];
@ -96,7 +78,7 @@ void all_reduce(
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType reduce_type) { Reduce::ReduceType reduce_type) {
constexpr int N_READS = 4; constexpr int N_READS = 8;
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
@ -118,14 +100,13 @@ void all_reduce(
} }
size_t reductions_per_block = std::max( size_t reductions_per_block = std::max(
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks); static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
size_t block_step = reductions_per_block * N_READS; size_t block_step = reductions_per_block * N;
return std::make_tuple(blocks, threads, block_step); return std::make_tuple(blocks, threads, block_step);
}; };
int blocks, threads; int blocks, threads;
size_t block_step; size_t block_step;
bool large = in.size() > N_READS * 1024;
array x = in; array x = in;
// Large array so allocate an intermediate and accumulate there // Large array so allocate an intermediate and accumulate there

View File

@ -3,49 +3,10 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
namespace mlx::core::cu { namespace mlx::core::cu {
template <size_t N>
struct uint_by_size;
template <>
struct uint_by_size<2> {
using type = uint16_t;
};
template <>
struct uint_by_size<4> {
using type = uint32_t;
};
template <>
struct uint_by_size<8> {
using type = unsigned long long int;
};
template <typename T, typename Op>
__device__ void atomic_reduce(T* x, T y) {
if constexpr (sizeof(T) == 1) {
using U = uint16_t;
U* x_int = (U*)((char*)x - ((size_t)x % 2));
int shift = ((char*)x - (char*)x_int) * 8;
int mask = 0xff << shift;
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
new_val = (old_val & ~mask) | (result << shift);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
} else {
using U = typename uint_by_size<sizeof(T)>::type;
U* x_int = (U*)(x);
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(*((T*)&old_val), y);
new_val = *((U*)&result);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
}
}
// Reduce ops. // Reduce ops.
struct And { struct And {
__device__ __forceinline__ bool operator()(bool a, bool b) { __device__ __forceinline__ bool operator()(bool a, bool b) {

View File

@ -0,0 +1,65 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device/utils.cuh"
namespace mlx::core::cu {
template <size_t N>
struct uint_by_size;
template <>
struct uint_by_size<2> {
using type = uint16_t;
};
template <>
struct uint_by_size<4> {
using type = uint32_t;
};
template <>
struct uint_by_size<8> {
using type = unsigned long long int;
};
template <typename T, typename Op>
__device__ void atomic_reduce(T* x, T y) {
if constexpr (sizeof(T) == 1) {
using U = uint16_t;
U* x_int = (U*)((char*)x - ((size_t)x % 2));
int shift = ((char*)x - (char*)x_int) * 8;
int mask = 0xff << shift;
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
new_val = (old_val & ~mask) | (result << shift);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
} else {
using U = typename uint_by_size<sizeof(T)>::type;
U* x_int = (U*)(x);
U old_val, new_val;
do {
old_val = *x_int;
T result = Op{}(*((T*)&old_val), y);
new_val = *((U*)&result);
} while (atomicCAS(x_int, old_val, new_val) != old_val);
}
}
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
return x.x != 0 && x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
} // namespace mlx::core::cu

View File

@ -188,8 +188,152 @@ __global__ void row_reduce_looped(
} }
} }
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void
row_reduce_per_threadblock(T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[M][N];
U accs[M];
for (int i = 0; i < M; i++) {
accs[i] = init;
}
const size_t start_row =
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
in += start_row * size;
out += start_row;
int i = 0;
for (; i + block.size() * N <= size; i += block.size() * N) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(), in + k * size + i, vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
if (size > i) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + i,
vals[k],
size,
__cast<T, U>(init));
for (int j = 0; i < N; i++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
for (int i = 0; i < M; i++) {
accs[i] = cg::reduce(warp, accs[i], op);
}
if (warp.meta_group_size() > 1) {
__shared__ U shared_accumulators[32 * M];
if (warp.thread_rank() == 0) {
for (int i = 0; i < M; i++) {
shared_accumulators[warp.meta_group_rank() * M + i] = accs[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < M; i++) {
accs[i] = shared_accumulators[warp.thread_rank() * M + i];
}
} else {
for (int i = 0; i < M; i++) {
accs[i] = init;
}
}
for (int i = 0; i < M; i++) {
accs[i] = cg::reduce(warp, accs[i], op);
}
}
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
out[i] = accs[i];
}
}
}
}
} // namespace cu } // namespace cu
void row_reduce_simple(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Initialize out such that its strides match in's layout (except the fastest
// moving axis)
auto [_, out_strides] = shapes_without_reduction_axes(in, axes);
for (auto& s : out_strides) {
s /= plan.shape.back();
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Calculate the grid and block dims
size_t reductions = plan.shape.back() / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
dim3 block(threads, 1, 1);
auto kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS>;
if (grid.x >= 1024) {
grid.x = (grid.x + 1) / 2;
kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS, 2>;
}
kernel<<<grid, block, 0, stream>>>(
x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
});
});
});
}
void row_reduce( void row_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@ -197,54 +341,58 @@ void row_reduce(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
cu::RowReduceArgs args(in, plan, axes); if (plan.shape.size() == 1) {
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
}
// cu::RowReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) { // encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { // MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>; // using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { // MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type; // using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { // MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr size_t N_READS = 4; // constexpr size_t N_READS = 4;
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides()); // dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims, num_blocks; // dim3 block_dims, num_blocks;
auto kernel = // auto kernel =
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>; // cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
if (args.row_size <= 64) { // if (args.row_size <= 64) {
if ((args.non_row_reductions < 32 && args.row_size <= 8) || // if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
(args.non_row_reductions <= 8)) { // (args.non_row_reductions <= 8)) {
block_dims.x = std::min(out_dims.x, 1024u); // block_dims.x = std::min(out_dims.x, 1024u);
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x); // num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
num_blocks.y = out_dims.y; // num_blocks.y = out_dims.y;
} else { // } else {
block_dims.x = WARP_SIZE; // block_dims.x = WARP_SIZE;
num_blocks.y = out_dims.x; // num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y; // num_blocks.z = out_dims.y;
kernel = // kernel =
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>; // cu::row_reduce_small_warp<InType, OutType, OP, NDIM,
} // N_READS>;
} else { // }
size_t num_threads = cuda::ceil_div(args.row_size, N_READS); // } else {
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE; // size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, { // num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
num_blocks.y = out_dims.x; // MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
num_blocks.z = out_dims.y; // num_blocks.y = out_dims.x;
block_dims.x = BLOCK_DIM_X; // num_blocks.z = out_dims.y;
kernel = cu::row_reduce_looped< // block_dims.x = BLOCK_DIM_X;
InType, // kernel = cu::row_reduce_looped<
OutType, // InType,
OP, // OutType,
NDIM, // OP,
BLOCK_DIM_X, // NDIM,
N_READS>; // BLOCK_DIM_X,
}); // N_READS>;
} // });
kernel<<<num_blocks, block_dims, 0, stream>>>( // }
in.data<InType>(), out.data<OutType>(), out.size(), args); // kernel<<<num_blocks, block_dims, 0, stream>>>(
}); // in.data<InType>(), out.data<OutType>(), out.size(), args);
}); // });
}); // });
}); // });
// });
} }
} // namespace mlx::core } // namespace mlx::core