Add comments and clean up

Angelos Katharopoulos 2025-06-21 12:44:26 -07:00
parent abdb21f27c
commit 664d8e42b8


@@ -77,88 +77,6 @@ struct RowReduceArgs {
  }
};
// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
//__global__ void row_reduce_small(
// const T* in,
// U* out,
// size_t out_size,
// const __grid_constant__ RowReduceArgs args) {
// size_t out_idx = cg::this_grid().thread_rank();
// if (out_idx >= out_size) {
// return;
// }
//
// Op op;
//
// U total_val = ReduceInit<Op, T>::value();
// LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
//
// in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
// args.ndim);
//
// for (size_t n = 0; n < args.non_row_reductions; n++) {
// for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
// U vals[N_READS];
// cub::LoadDirectBlocked(
// r,
// make_cast_iterator<U>(in + loop.location()),
// vals,
// args.row_size,
// ReduceInit<Op, T>::value());
// total_val = op(total_val, cub::ThreadReduce(vals, op));
// }
// loop.next(args.reduce_shape.data(), args.reduce_strides.data());
// }
//
// out[out_idx] = total_val;
// }
//
// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
//__global__ void row_reduce_small_warp(
// const T* in,
// U* out,
// size_t out_size,
// const __grid_constant__ RowReduceArgs args) {
// auto grid = cg::this_grid();
// auto block = cg::this_thread_block();
// auto warp = cg::tiled_partition<WARP_SIZE>(block);
//
// size_t out_idx = grid.thread_rank() / WARP_SIZE;
// if (out_idx >= out_size) {
// return;
// }
//
// Op op;
//
// U total_val = ReduceInit<Op, T>::value();
// LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
//
// in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
// args.ndim);
//
// for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
// n += WARP_SIZE) {
// for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
// U vals[N_READS];
// cub::LoadDirectBlocked(
// r,
// make_cast_iterator<U>(in + loop.location()),
// vals,
// args.row_size,
// ReduceInit<Op, T>::value());
// total_val = op(total_val, cub::ThreadReduce(vals, op));
// }
// loop.next(WARP_SIZE, args.reduce_shape.data(),
// args.reduce_strides.data());
// }
//
// total_val = cg::reduce(warp, total_val, op);
//
// if (warp.thread_rank() == 0) {
// out[out_idx] = total_val;
// }
// }
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
  auto grid = cg::this_grid();
@@ -286,100 +204,8 @@ __global__ void row_reduce_looped(
  }
}
template <typename T, typename U, typename Op, int N = 4>
__global__ void reduce_initialize(U* out, size_t out_size) {
  auto grid = cg::this_grid();
  if (grid.thread_rank() * N + N <= out_size) {
    for (int i = 0; i < N; i++) {
      out[grid.thread_rank() * N + i] = ReduceInit<Op, T>::value();
    }
  } else {
    for (int i = grid.thread_rank() * N; i < out_size; i++) {
      out[i] = ReduceInit<Op, T>::value();
    }
  }
}

template <typename T, typename U, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void row_reduce_atomics(
    T* in,
    U* out,
    size_t out_size,
    const __grid_constant__ RowReduceArgs args) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  size_t reduction_idx = grid.block_rank() / out_size;
  size_t out_idx = grid.block_rank() % out_size;

  Op op;

  U total[1];
  U init = ReduceInit<Op, T>::value();
  total[0] = init;

  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;

  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
  in += elem_to_loc(
      reduction_idx,
      args.reduce_shape.data(),
      args.reduce_strides.data(),
      args.reduce_ndim);

  for (size_t r = 0; r < full_blocks; r++) {
    T vals[N_READS];
    cub::LoadDirectBlockedVectorized<T, N_READS>(
        block.thread_rank(), in + r * BLOCK_DIM * N_READS, vals);
    for (int i = 0; i < N_READS; i++) {
      total[0] = op(total[0], __cast<U, T>(vals[i]));
    }
  }
  if (final_offset < args.row_size) {
    T vals[N_READS];
    cub::LoadDirectBlocked(
        block.thread_rank(),
        in + final_offset,
        vals,
        args.row_size - final_offset,
        __cast<T, U>(init));
    for (int i = 0; i < N_READS; i++) {
      total[0] = op(total[0], __cast<U, T>(vals[i]));
    }
  }

  __shared__ U shared_accumulators[32];
  block_reduce(block, warp, total, shared_accumulators, op, init);

  if (block.thread_rank() == 0) {
    op.atomic_update(out + out_idx, total[0]);
  }
}
} // namespace cu
void reduce_initialize(
    cu::CommandEncoder& encoder,
    array& out,
    Reduce::ReduceType reduce_type) {
  constexpr int N_WRITES = 8;
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using T = cuda_type_t<CTYPE>;
        using U = cu::ReduceResult<OP, T>::type;
        auto kernel = cu::reduce_initialize<T, U, OP, N_WRITES>;
        auto [grid, block] =
            get_launch_args(kernel, out, out.size() >= 1UL << 31, N_WRITES);
        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
      });
    });
  });
}
void row_reduce_simple(
    cu::CommandEncoder& encoder,
    const array& in,
@@ -480,66 +306,6 @@ void row_reduce_looped(
  });
}
void row_reduce_atomics(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan,
    cu::RowReduceArgs args) {
  constexpr int N_READS = 8;

  // Allocate data for the output using in's layout to access them as
  // contiguously as possible.
  allocate_same_layout(out, in, axes);

  // Just a way to get out of the constness because cub doesn't like it ...
  // (sigh)
  array x = in;

  // Initialize
  reduce_initialize(encoder, out, reduce_type);

  // Launch the reduction
  encoder.set_input_array(x);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using T = cuda_type_t<CTYPE>;
        using U = cu::ReduceResult<OP, T>::type;

        args.convert_shapes_to_contiguous(x, axes);
        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
        if (grid.x * args.non_row_reductions < INT_MAX) {
          grid.x *= args.non_row_reductions;
        } else if (grid.y * args.non_row_reductions < 65536) {
          grid.y *= args.non_row_reductions;
        } else {
          throw std::runtime_error(
              "[row_reduce_atomics] Non-row reductions need to be factorized which is NYI");
        }
        size_t reductions = args.row_size / N_READS;
        int threads = std::min(1024UL, reductions);
        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
        dim3 block(threads, 1, 1);

        // Pick the kernel
        auto kernel = cu::row_reduce_atomics<T, U, OP, 32, N_READS>;
        MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
          kernel = cu::row_reduce_atomics<T, U, OP, THREADS, N_READS>;
          block.x = THREADS;
        });

        // Launch
        kernel<<<grid, block, 0, stream>>>(
            x.data<T>(), out.data<U>(), out.size(), args);
      });
    });
  });
}
void row_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
@@ -547,6 +313,23 @@ void row_reduce(
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
  // Current row reduction options
  //
  // - row_reduce_simple
  //
  //   That means that we are simply reducing across the fastest moving axis.
  //   We are reducing 1 or 2 rows per threadblock depending on the size of
  //   the output.
  //
  // - row_reduce_looped
  //
  //   It is a general row reduction. We are computing 1 output per
  //   threadblock. We read the fastest moving axis vectorized and loop over
  //   the rest of the axes.
  //
  // Notes: We opt to read as much in order as possible and leave
  //        transpositions as they are (contrary to our Metal backend).
  // Simple row reduce means that we have 1 axis that we are reducing over and
  // it has stride 1.
  if (plan.shape.size() == 1) {
@@ -557,62 +340,8 @@ void row_reduce(
  // Make the args struct to help route to the best kernel
  cu::RowReduceArgs args(in, plan, axes);
  // Let's use atomics to increase parallelism
  if (false && args.row_size < 512) {
    row_reduce_atomics(
        encoder, in, out, reduce_type, axes, plan, std::move(args));
  }
  // Fallback row reduce
  row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
// encoder.launch_kernel([&](cudaStream_t stream) {
// MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
// using InType = cuda_type_t<CTYPE>;
// MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
// using OutType = cu::ReduceResult<OP, InType>::type;
// MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
// constexpr size_t N_READS = 4;
// dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
// dim3 block_dims, num_blocks;
// auto kernel =
// cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
// if (args.row_size <= 64) {
// if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
// (args.non_row_reductions <= 8)) {
// block_dims.x = std::min(out_dims.x, 1024u);
// num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
// num_blocks.y = out_dims.y;
// } else {
// block_dims.x = WARP_SIZE;
// num_blocks.y = out_dims.x;
// num_blocks.z = out_dims.y;
// kernel =
// cu::row_reduce_small_warp<InType, OutType, OP, NDIM,
// N_READS>;
// }
// } else {
// size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
// num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
// MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
// num_blocks.y = out_dims.x;
// num_blocks.z = out_dims.y;
// block_dims.x = BLOCK_DIM_X;
// kernel = cu::row_reduce_looped<
// InType,
// OutType,
// OP,
// NDIM,
// BLOCK_DIM_X,
// N_READS>;
// });
// }
// kernel<<<num_blocks, block_dims, 0, stream>>>(
// in.data<InType>(), out.data<OutType>(), out.size(), args);
// });
// });
// });
// });
}
} // namespace mlx::core
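
Editor's note: the comment added in row_reduce above describes the looped strategy as computing 1 output per threadblock while reading the fastest moving axis. Below is a minimal, self-contained sketch of that idea. It is not MLX code: row_sum and every other name in it are illustrative, it only handles a contiguous row with a plain sum rather than a generic Op, and it does not vectorize loads or handle non-row reduction axes. Each block reduces one row; each thread accumulates a strided slice of that row, and a shared-memory tree reduction combines the partial sums.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// One threadblock reduces one contiguous row of `row_size` elements.
template <typename T, int BLOCK_DIM>
__global__ void row_sum(const T* in, T* out, int row_size) {
  __shared__ T shared[BLOCK_DIM];
  const T* row = in + blockIdx.x * static_cast<size_t>(row_size);

  // Each thread accumulates a strided slice of the row.
  T total = T(0);
  for (int i = threadIdx.x; i < row_size; i += BLOCK_DIM) {
    total += row[i];
  }

  // Tree reduction across the block in shared memory
  // (BLOCK_DIM must be a power of two).
  shared[threadIdx.x] = total;
  __syncthreads();
  for (int stride = BLOCK_DIM / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) {
      shared[threadIdx.x] += shared[threadIdx.x + stride];
    }
    __syncthreads();
  }

  // Thread 0 writes the single output for this row.
  if (threadIdx.x == 0) {
    out[blockIdx.x] = shared[0];
  }
}

int main() {
  constexpr int n_rows = 4;
  constexpr int row_size = 1024;
  constexpr int block = 256;

  std::vector<float> h_in(static_cast<size_t>(n_rows) * row_size, 1.0f);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, n_rows * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);

  // One block per row: the "1 output per threadblock" scheme.
  row_sum<float, block><<<n_rows, block>>>(d_in, d_out, row_size);

  std::vector<float> h_out(n_rows);
  cudaMemcpy(h_out.data(), d_out, n_rows * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0 sum = %.1f (expected %d)\n", h_out[0], row_size);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Per the comment in row_reduce, the kernels in this file additionally read N_READS elements per thread (vectorized) and loop over the remaining reduction axes, but the block-per-output structure is the same.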