mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Working row reduce looped
This commit is contained in:
parent
4d2b682a13
commit
cd523ffd9f
@ -136,61 +136,8 @@ __global__ void row_reduce_small_warp(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||
__global__ void
|
||||
row_reduce_per_threadblock(T* in, U* out, size_t n_rows, int size) {
|
||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
@ -274,6 +221,72 @@ row_reduce_per_threadblock(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
|
||||
size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + r * BLOCK_DIM_X * N_READS,
|
||||
vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(total_val, __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + final_offset,
|
||||
vals,
|
||||
args.row_size - final_offset,
|
||||
__cast<T, U>(ReduceInit<Op, T>::value()));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(total_val, __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce_simple(
|
||||
@ -287,7 +300,7 @@ void row_reduce_simple(
|
||||
|
||||
// Initialize out such that its strides match in's layout (except the fastest
|
||||
// moving axis)
|
||||
auto [_, out_strides] = shapes_without_reduction_axes(in, axes);
|
||||
auto out_strides = in.strides();
|
||||
for (auto& s : out_strides) {
|
||||
s /= plan.shape.back();
|
||||
}
|
||||
@ -321,12 +334,17 @@ void row_reduce_simple(
|
||||
size_t reductions = plan.shape.back() / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
auto kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS>;
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
||||
if (grid.x >= 1024) {
|
||||
grid.x = (grid.x + 1) / 2;
|
||||
kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS, 2>;
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
|
||||
});
|
||||
@ -334,6 +352,75 @@ void row_reduce_simple(
|
||||
});
|
||||
}
|
||||
|
||||
void row_reduce_looped(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Initialize out such that it matches in's layout. Basically we keep any
|
||||
// transpositions as it were and that allows us to skip finding the location
|
||||
// of the output that matches the input.
|
||||
auto out_strides = in.strides();
|
||||
for (auto ax : axes) {
|
||||
for (auto& s : out_strides) {
|
||||
if (s > in.strides(ax)) {
|
||||
s /= in.shape(ax);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
auto fl = in.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = data_size == out.size();
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
data_size,
|
||||
out_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
|
||||
// Just a way to get out of the constness because cub doesn't like it ...
|
||||
// (sigh)
|
||||
array x = in;
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Calculate the grid and block dims
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = args.row_size / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
|
||||
kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
|
||||
block.x = THREADS;
|
||||
});
|
||||
});
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
@ -341,10 +428,14 @@ void row_reduce(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
// Simple row reduce means that we have 1 axis that we are reducing over and
|
||||
// it has stride 1.
|
||||
if (plan.shape.size() == 1) {
|
||||
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
|
||||
}
|
||||
// cu::RowReduceArgs args(in, plan, axes);
|
||||
|
||||
// Fallback row reduce
|
||||
row_reduce_looped(encoder, in, out, reduce_type, axes, plan);
|
||||
|
||||
// encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
// MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
|
Loading…
Reference in New Issue
Block a user