Fix and refactor row-reduce (#2650)

This commit is contained in:
Angelos Katharopoulos
2025-10-07 01:51:08 -07:00
committed by GitHub
parent a393435d28
commit e3d004fed9

View File

@@ -7,8 +7,6 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -83,7 +81,8 @@ struct RowReduceArgs {
};
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
__global__ void
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[M][N];
U accs[M];
AlignedVector<T, N> vals[M];
AlignedVector<U, M> accs;
for (int i = 0; i < M; i++) {
accs[i] = init;
}
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
in += start_row * size + block.thread_rank() * N;
out += start_row;
if (size % N == 0) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
vals[k] = load_vector<N>(in + k * size, 0);
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
in += block.size() * N;
}
if (final_offset < size) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + final_offset,
vals[k],
size,
cast_to<T>(init));
for (int i = 0; i < N; i++) {
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
? in[k * size + i]
: cast_to<T>(init);
}
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
}
__shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs, shared_accumulators, op, init);
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
store_vector(out, 0, accs);
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM,
int N_READS = 4>
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_looped(
T* in,
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
const size_t full_blocks = args.row_size / (block.size() * N_READS);
const size_t final_offset = full_blocks * (block.size() * N_READS);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < full_blocks; r++) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
cast_to<T>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
{
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
__shared__ U shared_accumulators[32];
@@ -234,8 +236,6 @@ void row_reduce_simple(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
@@ -250,14 +250,15 @@ void row_reduce_simple(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
@@ -267,6 +268,7 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -282,8 +284,6 @@ void row_reduce_looped(
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -295,34 +295,27 @@ void row_reduce_looped(
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
});
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
});
});
}