mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-11 15:06:42 +08:00
Fix and refactor row-reduce (#2650)
This commit is contained in:
committed by
GitHub
parent
a393435d28
commit
e3d004fed9
@@ -7,8 +7,6 @@
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -83,7 +81,8 @@ struct RowReduceArgs {
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
__global__ void
|
||||
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[M][N];
|
||||
U accs[M];
|
||||
AlignedVector<T, N> vals[M];
|
||||
AlignedVector<U, M> accs;
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||
const size_t full_blocks = size / (block.size() * N);
|
||||
const size_t final_offset = full_blocks * (block.size() * N);
|
||||
in += start_row * size;
|
||||
in += start_row * size + block.thread_rank() * N;
|
||||
out += start_row;
|
||||
|
||||
if (size % N == 0) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
vals[k] = load_vector<N>(in + k * size, 0);
|
||||
}
|
||||
for (int k = 0; k < M; k++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
|
||||
in += block.size() * N;
|
||||
}
|
||||
|
||||
if (final_offset < size) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + final_offset,
|
||||
vals[k],
|
||||
size,
|
||||
cast_to<T>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
|
||||
? in[k * size + i]
|
||||
: cast_to<T>(init);
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < M; k++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32 * M];
|
||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
||||
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
if (grid.block_rank() * M + M <= n_rows) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
store_vector(out, 0, accs);
|
||||
} else {
|
||||
short offset = grid.block_rank() * M + M - n_rows;
|
||||
for (int i = offset; i < M; i++) {
|
||||
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM,
|
||||
int N_READS = 4>
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
T* in,
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
total[0] = init;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
||||
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
||||
const size_t full_blocks = args.row_size / (block.size() * N_READS);
|
||||
const size_t final_offset = full_blocks * (block.size() * N_READS);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
in += block.thread_rank() * N_READS;
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + r * BLOCK_DIM * N_READS,
|
||||
vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
// Unaligned reduce
|
||||
if (final_offset < args.row_size) {
|
||||
bool mask[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
mask[i] =
|
||||
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + final_offset,
|
||||
vals,
|
||||
args.row_size - final_offset,
|
||||
cast_to<T>(init));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
{
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
}
|
||||
|
||||
// Aligned case
|
||||
else {
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
// TODO: Maybe block.sync() here?
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32];
|
||||
@@ -234,8 +236,6 @@ void row_reduce_simple(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||
// kernel.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -250,14 +250,15 @@ void row_reduce_simple(
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
@@ -267,6 +268,7 @@ void row_reduce_simple(
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
int size = plan.shape.back();
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||
@@ -282,8 +284,6 @@ void row_reduce_looped(
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::RowReduceArgs args) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -295,34 +295,27 @@ void row_reduce_looped(
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
||||
kernel = cu::row_reduce_looped<
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim.value,
|
||||
threads_constant.value,
|
||||
N_READS>;
|
||||
block.x = threads_constant.value;
|
||||
});
|
||||
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
|
||||
});
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user