Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)
Working col reduce
parent 664d8e42b8
commit cc4b995723
@@ -64,86 +64,6 @@ struct ColReduceArgs {
   }
 };
 
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void col_reduce_small(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-
-  int column =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  if (column * N_READS >= args.reduction_stride) {
-    return;
-  }
-
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  Op op;
-  U totals[N_READS];
-  for (int i = 0; i < N_READS; i++) {
-    totals[i] = ReduceInit<Op, T>::value();
-  }
-
-  // Read input to local.
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(
-      block.thread_index().y,
-      args.reduce_shape.data(),
-      args.reduce_strides.data());
-  for (size_t r = block.thread_index().y;
-       r < args.non_col_reductions * args.reduction_size;
-       r += block.dim_threads().y) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location()),
-        vals,
-        args.reduction_stride,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
-    }
-    loop.next(
-        block.dim_threads().y,
-        args.reduce_shape.data(),
-        args.reduce_strides.data());
-  }
-
-  // Do block reduce when each column has more than 1 element to reduce.
-  if (block.dim_threads().y > 1) {
-    __shared__ U shared_vals[32 * 8 * N_READS];
-    size_t col =
-        block.thread_index().y * block.dim_threads().x + block.thread_index().x;
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[col * N_READS + i] = totals[i];
-    }
-    block.sync();
-    if (block.thread_index().y == 0) {
-      for (int i = 0; i < N_READS; i++) {
-        totals[i] = shared_vals[block.thread_index().x * N_READS + i];
-      }
-      for (int j = 1; j < block.dim_threads().y; j++) {
-        col = j * block.dim_threads().x + block.thread_index().x;
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
-        }
-      }
-    }
-  }
-
-  // Write result.
-  if (block.thread_index().y == 0) {
-    cub::StoreDirectBlocked(
-        column,
-        out + out_idx * args.reduction_stride,
-        totals,
-        args.reduction_stride);
-  }
-}
-
 template <
     typename T,
     typename U,
@@ -152,67 +72,83 @@ template <
     int BM,
     int BN,
     int N_READS = 4>
-__global__ void col_reduce_looped(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
+__global__ void
+col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
 
-  constexpr int n_warps = BN / N_READS;
+  constexpr int threads_per_row = BN / N_READS;
 
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  // Compute the indices for the tile
+  size_t tile_idx = grid.block_rank();
+  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
+  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
 
+  // Compute the indices for the thread within the tile
+  short thread_x = block.thread_rank() % threads_per_row;
+  short thread_y = block.thread_rank() / threads_per_row;
+
+  // Move the input pointer
+  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
+      tile_x * BN;
+
+  // Initialize the running totals
   Op op;
   U totals[N_READS];
   for (int i = 0; i < N_READS; i++) {
     totals[i] = ReduceInit<Op, T>::value();
   }
 
-  // Read input to local.
-  int r = block.thread_rank() / n_warps;
-  int column = block.thread_rank() % n_warps;
-  int in_offset = grid.block_index().x * BN;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
-  for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location() + in_offset),
-        vals,
-        args.reduction_stride - in_offset,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
-    }
-    loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
-  }
+  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
+  size_t total = args.non_col_reductions * args.reduction_size;
+  if (tile_x * BN + BN <= args.reduction_stride) {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+    }
+  } else {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlocked(
+          thread_x,
+          in + loop.location(),
+          vals,
+          args.reduction_stride - tile_x * BN,
+          __cast<T, U>(ReduceInit<Op, T>::value()));
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+    }
+  }
 
   // Do warp reduce for each output.
-  constexpr int n_outputs = BN / n_warps;
+  constexpr int n_outputs = BN / threads_per_row;
   static_assert(BM == 32 && n_outputs == N_READS);
   __shared__ U shared_vals[BM * BN];
-  size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
+  short s_idx = thread_y * BN + thread_x * N_READS;
   for (int i = 0; i < N_READS; i++) {
-    shared_vals[col + i] = totals[i];
+    shared_vals[s_idx + i] = totals[i];
   }
   block.sync();
-  col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
+  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
   for (int i = 0; i < n_outputs; i++) {
-    totals[i] = cg::reduce(warp, shared_vals[col + i], op);
+    totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
   }
 
   // Write result.
   if (warp.thread_rank() == 0) {
-    size_t out_offset = grid.block_index().x * BN;
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
-        out + out_idx * args.reduction_stride + out_offset,
+        out + tile_y * args.reduction_stride + tile_x * BN,
         totals,
-        args.reduction_stride - out_offset);
+        args.reduction_stride - tile_x * BN);
   }
 }
 
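Illustrative sketch (not part of this commit): the rewritten col_reduce_looped derives a column tile (tile_x) and an output row (tile_y) from the flat block rank using the ceil-divided number of BN-wide tiles, and splits the flat thread rank into thread_x (which N_READS-wide chunk of the tile) and thread_y (which reduction row to start from). The host-side program below mirrors that arithmetic with assumed example sizes (reduction_stride = 70, BN = 32, N_READS = 4) so the mapping can be inspected directly.

// Standalone sketch, not part of the commit: reproduces the kernel's
// tile/thread index arithmetic on the host with assumed example sizes.
#include <cstdio>

int main() {
  const int BN = 32, BM = 32, N_READS = 4;
  const int threads_per_row = BN / N_READS;        // 8 threads cover one BN-wide row
  const long reduction_stride = 70;                // assumed fast-axis length
  const long n_tiles_x = (reduction_stride + BN - 1) / BN; // ceil-div, 3 tiles

  // Block-level mapping: flat block rank -> (tile_x, tile_y).
  for (long block_rank = 0; block_rank < 2 * n_tiles_x; block_rank++) {
    long tile_x = block_rank % n_tiles_x;
    long tile_y = block_rank / n_tiles_x;
    std::printf("block %ld -> tile_x=%ld tile_y=%ld\n", block_rank, tile_x, tile_y);
  }

  // Thread-level mapping: flat thread rank -> (thread_x, thread_y).
  for (int thread_rank = 0; thread_rank < BM * threads_per_row; thread_rank += 63) {
    int thread_x = thread_rank % threads_per_row;  // which N_READS-wide chunk
    int thread_y = thread_rank / threads_per_row;  // which reduction row
    std::printf("thread %d -> thread_x=%d thread_y=%d\n",
                thread_rank, thread_x, thread_y);
  }
  return 0;
}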
@@ -230,6 +166,53 @@ inline auto output_grid_for_col_reduce(
   return get_2d_grid_dims(out_shape, out_strides);
 }
 
+void col_reduce_looped(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::ColReduceArgs args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+
+          constexpr int N_READS = 4;
+          constexpr int BM = 32;
+          constexpr int BN = 32;
+          dim3 grid = output_grid_for_col_reduce(out, args);
+          size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
+          if (grid.x * extra_blocks < INT32_MAX) {
+            grid.x *= extra_blocks;
+          } else if (grid.y * extra_blocks < 65536) {
+            grid.y *= extra_blocks;
+          } else {
+            throw std::runtime_error(
+                "[col_reduce_looped] Need to factorize reduction_stride");
+          }
+          int blocks = BM * BN / N_READS;
+          auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
+          kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
+        });
+      });
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
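Illustrative sketch (not part of this commit): the launcher above multiplies the output grid by extra_blocks = ceil(reduction_stride / BN) column tiles, and the kernel only takes the LoadDirectBlockedVectorized path for tiles that lie fully inside reduction_stride. The small program below works through an assumed example (reduction_stride = 70, BN = 32): three tiles are launched per output row and only the last one falls back to the bounds-checked LoadDirectBlocked path.

// Standalone sketch, not part of the commit: shows, for assumed sizes, how many
// BN-wide column tiles are added to the grid and which of them are full tiles.
#include <cstdio>

int main() {
  const int BN = 32;
  const long reduction_stride = 70;                           // assumed fast-axis length
  const long extra_blocks = (reduction_stride + BN - 1) / BN; // ceil-div = 3

  std::printf("extra_blocks = %ld\n", extra_blocks);
  for (long tile_x = 0; tile_x < extra_blocks; tile_x++) {
    bool full_tile = tile_x * BN + BN <= reduction_stride;
    std::printf("tile_x=%ld -> %s\n", tile_x,
                full_tile ? "vectorized load (full tile)"
                          : "bounds-checked load (partial tile)");
  }
  return 0;
}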
@@ -237,42 +220,24 @@ void col_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current col reduce options
+  //
+  // - col_reduce_looped
+  //
+  // It is a general strided reduce. Each threadblock computes the output for
+  // a subrow of the fast moving axis. For instance 32 elements.
+  //
+  // Notes: As in row reduce we opt to read as much in order as possible and
+  // leave
+  // transpositions as they are (contrary to our Metal backend).
+  //
+  // Moreover we need different kernels for short rows and tuning
+
+  // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);
 
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr int N_READS = 4;
-          dim3 block_dims;
-          dim3 num_blocks = output_grid_for_col_reduce(out, args);
-          num_blocks.z = num_blocks.y;
-          num_blocks.y = num_blocks.x;
-          auto kernel =
-              cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          size_t total = args.non_col_reductions * args.reduction_size;
-          if (total < 32) {
-            size_t stride_blocks =
-                cuda::ceil_div(args.reduction_stride, N_READS);
-            block_dims.x = std::min(stride_blocks, 32ul);
-            block_dims.y = std::min(total, 8ul);
-            num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
-          } else {
-            constexpr int BM = 32;
-            constexpr int BN = 32;
-            block_dims.x = BM * BN / N_READS;
-            num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
-            kernel = cu::
-                col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), args);
-        });
-      });
-    });
-  });
+  // Fallback col reduce
+  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
 
 } // namespace mlx::core
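Illustrative reference (not part of this commit): stripped of the generic shape and stride handling done by elem_to_loc and LoopedElemToLoc, the column reduce computes out[y][j] = reduce over r of in[y][r][j], where j runs along the fast-moving axis of length reduction_stride and r runs over the reduced rows. A minimal contiguous, sum-only CPU version with assumed sizes:

// Standalone reference sketch, not part of the commit: the simplest contiguous
// case of the column reduction with a sum op. The real kernels also handle
// non_col_reductions, arbitrary strides, and arbitrary reduce ops.
#include <cstdio>
#include <vector>

int main() {
  const long n_rows = 2;            // assumed: number of output rows
  const long reduction_size = 3;    // assumed: elements reduced per column
  const long reduction_stride = 4;  // assumed: fast-moving axis length

  std::vector<float> in(n_rows * reduction_size * reduction_stride);
  for (size_t i = 0; i < in.size(); i++) in[i] = float(i);

  std::vector<float> out(n_rows * reduction_stride, 0.0f);
  for (long y = 0; y < n_rows; y++) {
    for (long r = 0; r < reduction_size; r++) {
      for (long j = 0; j < reduction_stride; j++) {
        // Accumulate column j of reduced row r into output row y.
        out[y * reduction_stride + j] +=
            in[(y * reduction_size + r) * reduction_stride + j];
      }
    }
  }

  for (long y = 0; y < n_rows; y++) {
    for (long j = 0; j < reduction_stride; j++) {
      std::printf("%6.1f ", out[y * reduction_stride + j]);
    }
    std::printf("\n");
  }
  return 0;
}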