Working col reduce

Angelos Katharopoulos 2025-06-21 23:39:40 -07:00
parent 664d8e42b8
commit cc4b995723


@@ -64,86 +64,6 @@ struct ColReduceArgs {
  }
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
cub::StoreDirectBlocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
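Both kernels lean on cub's blocked load and store helpers. As a rough mental model of cub::LoadDirectBlocked (a CPU sketch of the semantics, not CUB's implementation): thread t handles the N_READS consecutive elements starting at t * N_READS, and slots past the valid range are filled with the supplied default, which is why the reduction identity is passed as the fill value.

#include <cstddef>

// CPU model of cub::LoadDirectBlocked semantics (illustrative only):
// thread `tid` copies items [tid * N, tid * N + N) from `src`; only the first
// `valid_items` elements of the row exist, and the rest receive `fill`.
template <typename T, int N>
void load_direct_blocked_model(
    int tid,
    const T* src,
    T (&dst)[N],
    std::size_t valid_items,
    T fill) {
  for (int i = 0; i < N; i++) {
    std::size_t idx = static_cast<std::size_t>(tid) * N + i;
    dst[i] = (idx < valid_items) ? src[idx] : fill;
  }
}

int main() {
  const float row[6] = {1, 2, 3, 4, 5, 6};
  float regs[4];
  // Thread 1 with N_READS = 4 wants elements 4..7, but only 6 are valid,
  // so regs ends up as {5, 6, 0, 0}.
  load_direct_blocked_model<float, 4>(1, row, regs, 6, 0.0f);
  return 0;
}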
template <
    typename T,
    typename U,
@@ -152,67 +72,83 @@ template <
    int BM,
    int BN,
    int N_READS = 4>
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  constexpr int threads_per_row = BN / N_READS;

  // Compute the indices for the tile
  size_t tile_idx = grid.block_rank();
  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);

  // Compute the indices for the thread within the tile
  short thread_x = block.thread_rank() % threads_per_row;
  short thread_y = block.thread_rank() / threads_per_row;

  // Move the input pointer
  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
      tile_x * BN;

  // Initialize the running totals
  Op op;
  U totals[N_READS];
  for (int i = 0; i < N_READS; i++) {
    totals[i] = ReduceInit<Op, T>::value();
  }

  // Read input to local.
  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
  size_t total = args.non_col_reductions * args.reduction_size;
  if (tile_x * BN + BN <= args.reduction_stride) {
    for (size_t r = thread_y; r < total; r += BM) {
      T vals[N_READS];
      cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
      for (int i = 0; i < N_READS; i++) {
        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
      }
      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
    }
  } else {
    for (size_t r = thread_y; r < total; r += BM) {
      T vals[N_READS];
      cub::LoadDirectBlocked(
          thread_x,
          in + loop.location(),
          vals,
          args.reduction_stride - tile_x * BN,
          __cast<T, U>(ReduceInit<Op, T>::value()));
      for (int i = 0; i < N_READS; i++) {
        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
      }
      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
    }
  }

  // Do warp reduce for each output.
  constexpr int n_outputs = BN / threads_per_row;
  static_assert(BM == 32 && n_outputs == N_READS);
  __shared__ U shared_vals[BM * BN];
  short s_idx = thread_y * BN + thread_x * N_READS;
  for (int i = 0; i < N_READS; i++) {
    shared_vals[s_idx + i] = totals[i];
  }
  block.sync();
  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
  for (int i = 0; i < n_outputs; i++) {
    totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
  }

  // Write result.
  if (warp.thread_rank() == 0) {
    cub::StoreDirectBlocked(
        warp.meta_group_rank(),
        out + tile_y * args.reduction_stride + tile_x * BN,
        totals,
        args.reduction_stride - tile_x * BN);
  }
}
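To make the tiling in col_reduce_looped concrete: every thread block owns one (tile_x, tile_y) pair, where tile_y selects a slice of the output and tile_x selects a BN-wide window of contiguous output columns; only the last window of a row can be partial. The sketch below is plain host C++ with made-up sizes (it is not part of the commit) and just replays that index arithmetic.

#include <cstddef>
#include <cstdio>

// Illustrative only: mirrors the tile index math of col_reduce_looped.
// reduction_stride is the number of output columns, BN the tile width.
void describe_tiles(std::size_t reduction_stride, std::size_t num_rows, std::size_t BN) {
  std::size_t tiles_per_row = (reduction_stride + BN - 1) / BN; // ceil division
  for (std::size_t tile_idx = 0; tile_idx < tiles_per_row * num_rows; tile_idx++) {
    std::size_t tile_x = tile_idx % tiles_per_row; // column tile within the row
    std::size_t tile_y = tile_idx / tiles_per_row; // which output slice
    bool full = tile_x * BN + BN <= reduction_stride; // vectorized or guarded
    std::printf(
        "block %zu -> tile (x=%zu, y=%zu), %s loads\n",
        tile_idx,
        tile_x,
        tile_y,
        full ? "vectorized" : "guarded");
  }
}

int main() {
  // Example: 70 output columns, 2 output slices, BN = 32 gives 3 tiles per
  // slice; the last tile of each slice covers only 6 columns and is guarded.
  describe_tiles(70, 2, 32);
  return 0;
}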
@@ -230,6 +166,53 @@ inline auto output_grid_for_col_reduce(
  return get_2d_grid_dims(out_shape, out_strides);
}
void col_reduce_looped(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args);
size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
if (grid.x * extra_blocks < INT32_MAX) {
grid.x *= extra_blocks;
} else if (grid.y * extra_blocks < 65536) {
grid.y *= extra_blocks;
} else {
throw std::runtime_error(
"[col_reduce_looped] Need to factorize reduction_stride");
}
int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
});
});
});
});
}
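As a sanity check on the launch configuration above: each block always carries BM * BN / N_READS threads (256 for the defaults chosen here), and the number of BN-wide column tiles is folded into grid.x when the product stays under the 2^31 - 1 limit, otherwise into grid.y, which CUDA caps at 65535. The snippet below is a standalone restatement of that decision with assumed example sizes, not MLX code.

#include <cstdint>
#include <cstdio>
#include <stdexcept>

int main() {
  constexpr int N_READS = 4, BM = 32, BN = 32;
  int threads_per_block = BM * BN / N_READS; // 32 * 32 / 4 = 256 threads

  // Assumed base grid (from the output shape) and a wide reduction stride.
  std::uint64_t grid_x = 1024, grid_y = 8;
  std::uint64_t reduction_stride = 100000;
  std::uint64_t extra_blocks = (reduction_stride + BN - 1) / BN; // 3125 tiles

  if (grid_x * extra_blocks < INT32_MAX) {
    grid_x *= extra_blocks; // one block per (output slice, column tile) pair
  } else if (grid_y * extra_blocks < 65536) {
    grid_y *= extra_blocks; // y dimension is limited to 65535 blocks
  } else {
    throw std::runtime_error("would need to factorize reduction_stride");
  }
  std::printf(
      "%d threads/block, grid = (%llu, %llu, 1)\n",
      threads_per_block,
      static_cast<unsigned long long>(grid_x),
      static_cast<unsigned long long>(grid_y));
  return 0;
}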
void col_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
@@ -237,42 +220,24 @@ void col_reduce(
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
// Current col reduce options
//
// - col_reduce_looped
//
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
// Moreover we need different kernels for short rows and tuning
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
cuda::ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
kernel = cu::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}
} // namespace mlx::core
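For readers new to this code, "col reduce" means reducing over an axis whose elements are not adjacent in memory: each output gathers inputs that sit reduction_stride apart, so neighboring threads still read neighboring addresses and stay coalesced. The snippet below is a minimal CPU reference of that computation for a plain 2D sum (hypothetical shapes and names, not the MLX API); the kernels above compute the same thing, just tiled, typed, and parallelized for arbitrary reduce shapes and ops.

#include <cstddef>
#include <cstdio>
#include <vector>

// Minimal CPU reference for a column reduction (sum), illustrative only.
// For a row-major [rows x stride] array, out[j] accumulates in[r * stride + j]
// over r, i.e. over elements that are `stride` apart in memory.
std::vector<float> col_sum(
    const std::vector<float>& in,
    std::size_t rows,
    std::size_t stride) {
  std::vector<float> out(stride, 0.0f); // one running total per column
  for (std::size_t r = 0; r < rows; r++) {
    for (std::size_t j = 0; j < stride; j++) {
      out[j] += in[r * stride + j];
    }
  }
  return out;
}

int main() {
  // 3 x 4 example: each of the 4 outputs sums one column of 3 elements.
  std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  for (float v : col_sum(x, 3, 4)) {
    std::printf("%g ", v); // prints: 12 15 18 21
  }
  std::printf("\n");
  return 0;
}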