mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 00:39:06 +08:00
Fix and refactor row-reduce (#2650)
This commit is contained in:
committed by
GitHub
parent
a393435d28
commit
e3d004fed9
@@ -7,8 +7,6 @@
|
|||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
#include <cub/block/block_load.cuh>
|
|
||||||
#include <cub/block/block_reduce.cuh>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -83,7 +81,8 @@ struct RowReduceArgs {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
__global__ void
|
||||||
|
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
|||||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||||
ReduceOp op;
|
ReduceOp op;
|
||||||
|
|
||||||
T vals[M][N];
|
AlignedVector<T, N> vals[M];
|
||||||
U accs[M];
|
AlignedVector<U, M> accs;
|
||||||
for (int i = 0; i < M; i++) {
|
for (int i = 0; i < M; i++) {
|
||||||
accs[i] = init;
|
accs[i] = init;
|
||||||
}
|
}
|
||||||
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
|||||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||||
const size_t full_blocks = size / (block.size() * N);
|
const size_t full_blocks = size / (block.size() * N);
|
||||||
const size_t final_offset = full_blocks * (block.size() * N);
|
const size_t final_offset = full_blocks * (block.size() * N);
|
||||||
in += start_row * size;
|
in += start_row * size + block.thread_rank() * N;
|
||||||
out += start_row;
|
out += start_row;
|
||||||
|
|
||||||
if (size % N == 0) {
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
for (int k = 0; k < M; k++) {
|
||||||
for (int k = 0; k < M; k++) {
|
vals[k] = load_vector<N>(in + k * size, 0);
|
||||||
cub::LoadDirectBlockedVectorized<T, N>(
|
}
|
||||||
block.thread_rank(),
|
for (int k = 0; k < M; k++) {
|
||||||
in + k * size + r * (block.size() * N),
|
for (int j = 0; j < N; j++) {
|
||||||
vals[k]);
|
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||||
for (int j = 0; j < N; j++) {
|
|
||||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
|
||||||
for (int k = 0; k < M; k++) {
|
|
||||||
cub::LoadDirectBlocked(
|
|
||||||
block.thread_rank(),
|
|
||||||
in + k * size + r * (block.size() * N),
|
|
||||||
vals[k]);
|
|
||||||
for (int j = 0; j < N; j++) {
|
|
||||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
in += block.size() * N;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (final_offset < size) {
|
if (final_offset < size) {
|
||||||
for (int k = 0; k < M; k++) {
|
for (int k = 0; k < M; k++) {
|
||||||
cub::LoadDirectBlocked(
|
for (int i = 0; i < N; i++) {
|
||||||
block.thread_rank(),
|
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
|
||||||
in + k * size + final_offset,
|
? in[k * size + i]
|
||||||
vals[k],
|
: cast_to<T>(init);
|
||||||
size,
|
}
|
||||||
cast_to<T>(init));
|
}
|
||||||
|
for (int k = 0; k < M; k++) {
|
||||||
for (int j = 0; j < N; j++) {
|
for (int j = 0; j < N; j++) {
|
||||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||||
}
|
}
|
||||||
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
__shared__ U shared_accumulators[32 * M];
|
__shared__ U shared_accumulators[32 * M];
|
||||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
|
||||||
|
|
||||||
if (block.thread_rank() == 0) {
|
if (block.thread_rank() == 0) {
|
||||||
if (grid.block_rank() * M + M <= n_rows) {
|
if (grid.block_rank() * M + M <= n_rows) {
|
||||||
for (int i = 0; i < M; i++) {
|
store_vector(out, 0, accs);
|
||||||
out[i] = accs[i];
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
short offset = grid.block_rank() * M + M - n_rows;
|
short offset = grid.block_rank() * M + M - n_rows;
|
||||||
for (int i = offset; i < M; i++) {
|
for (int i = offset; i < M; i++) {
|
||||||
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||||
typename T,
|
|
||||||
typename U,
|
|
||||||
typename Op,
|
|
||||||
int NDIM,
|
|
||||||
int BLOCK_DIM,
|
|
||||||
int N_READS = 4>
|
|
||||||
__global__ void row_reduce_looped(
|
__global__ void row_reduce_looped(
|
||||||
T* in,
|
const T* in,
|
||||||
U* out,
|
U* out,
|
||||||
size_t out_size,
|
|
||||||
const __grid_constant__ RowReduceArgs args) {
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
|
|||||||
U init = ReduceInit<Op, T>::value();
|
U init = ReduceInit<Op, T>::value();
|
||||||
total[0] = init;
|
total[0] = init;
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
const size_t full_blocks = args.row_size / (block.size() * N_READS);
|
||||||
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
const size_t final_offset = full_blocks * (block.size() * N_READS);
|
||||||
|
|
||||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
in += block.thread_rank() * N_READS;
|
||||||
|
|
||||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
// Unaligned reduce
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
if (final_offset < args.row_size) {
|
||||||
T vals[N_READS];
|
bool mask[N_READS];
|
||||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
for (int i = 0; i < N_READS; i++) {
|
||||||
block.thread_rank(),
|
mask[i] =
|
||||||
in + loop.location() + r * BLOCK_DIM * N_READS,
|
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
|
||||||
vals);
|
|
||||||
for (int i = 0; i < N_READS; i++) {
|
|
||||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (final_offset < args.row_size) {
|
|
||||||
T vals[N_READS];
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
cub::LoadDirectBlocked(
|
const T* inlocal = in + loop.location();
|
||||||
block.thread_rank(),
|
|
||||||
in + loop.location() + final_offset,
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
vals,
|
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||||
args.row_size - final_offset,
|
for (int i = 0; i < N_READS; i++) {
|
||||||
cast_to<T>(init));
|
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||||
for (int i = 0; i < N_READS; i++) {
|
}
|
||||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
inlocal += block.size() * N_READS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
T vals[N_READS];
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aligned case
|
||||||
|
else {
|
||||||
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
|
const T* inlocal = in + loop.location();
|
||||||
|
|
||||||
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
|
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||||
|
}
|
||||||
|
inlocal += block.size() * N_READS;
|
||||||
|
}
|
||||||
|
|
||||||
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
// TODO: Maybe block.sync() here?
|
|
||||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__shared__ U shared_accumulators[32];
|
__shared__ U shared_accumulators[32];
|
||||||
@@ -234,8 +236,6 @@ void row_reduce_simple(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan) {
|
const ReductionPlan& plan) {
|
||||||
constexpr int N_READS = 8;
|
|
||||||
|
|
||||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||||
// kernel.
|
// kernel.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -250,14 +250,15 @@ void row_reduce_simple(
|
|||||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
using U = typename cu::ReduceResult<OP, T>::type;
|
using U = typename cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
constexpr int N_READS = 16 / sizeof(T);
|
||||||
T* indata = const_cast<T*>(in.data<T>());
|
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
// Calculate the grid and block dims
|
||||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
int threads = std::min(1024UL, reductions);
|
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
warps /= 4;
|
||||||
|
warps = std::max(std::min(warps, 32), 1);
|
||||||
|
int threads = warps * WARP_SIZE;
|
||||||
dim3 block(threads, 1, 1);
|
dim3 block(threads, 1, 1);
|
||||||
|
|
||||||
// Pick the kernel
|
// Pick the kernel
|
||||||
@@ -267,6 +268,7 @@ void row_reduce_simple(
|
|||||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
T* indata = const_cast<T*>(in.data<T>());
|
||||||
int size = plan.shape.back();
|
int size = plan.shape.back();
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||||
@@ -282,8 +284,6 @@ void row_reduce_looped(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
cu::RowReduceArgs args) {
|
cu::RowReduceArgs args) {
|
||||||
constexpr int N_READS = 8;
|
|
||||||
|
|
||||||
// Allocate data for the output using in's layout to access them as
|
// Allocate data for the output using in's layout to access them as
|
||||||
// contiguously as possible.
|
// contiguously as possible.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -295,34 +295,27 @@ void row_reduce_looped(
|
|||||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
using U = typename cu::ReduceResult<OP, T>::type;
|
using U = typename cu::ReduceResult<OP, T>::type;
|
||||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
|
||||||
T* indata = const_cast<T*>(in.data<T>());
|
constexpr int N_READS = 16 / sizeof(T);
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
// Calculate the grid and block dims
|
||||||
args.sort_access_pattern(in, axes);
|
args.sort_access_pattern(in, axes);
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||||
int threads = std::min(1024UL, reductions);
|
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
warps /= 4;
|
||||||
|
warps = std::max(std::min(warps, 32), 1);
|
||||||
|
int threads = warps * WARP_SIZE;
|
||||||
dim3 block(threads, 1, 1);
|
dim3 block(threads, 1, 1);
|
||||||
|
|
||||||
// Pick the kernel
|
// Pick the kernel
|
||||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
|
||||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
|
||||||
kernel = cu::row_reduce_looped<
|
|
||||||
T,
|
|
||||||
U,
|
|
||||||
OP,
|
|
||||||
reduce_ndim.value,
|
|
||||||
threads_constant.value,
|
|
||||||
N_READS>;
|
|
||||||
block.x = threads_constant.value;
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user