This commit is contained in:
Angelos Katharopoulos
2025-06-26 22:54:09 -07:00
parent d999675cb9
commit bc60a31cae
5 changed files with 49 additions and 39 deletions

View File

@@ -96,43 +96,52 @@ void all_reduce(
int blocks, threads;
size_t block_step;
array x = in;
size_t insize = in.size();
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate);
encoder.set_input_array(x);
encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), intermediate.data<U>(), block_step, x.size());
static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
insize);
});
});
});
// Set the input for the next step and recalculate the blocks
x = intermediate;
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
indata = intermediate.data<void>();
dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(intermediate);
}
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), out.data<U>(), block_step, x.size());
static_cast<T*>(indata), out.data<U>(), block_step, insize);
});
});
});

View File

@@ -38,12 +38,18 @@ struct ColReduceArgs {
const array& in,
const ReductionPlan& plan,
const std::vector<int>& axes) {
using ShapeVector = decltype(plan.shape);
using StridesVector = decltype(plan.strides);
ShapeVector shape_vec;
StridesVector strides_vec;
assert(!plan.shape.empty());
reduction_size = plan.shape.back();
reduction_stride = plan.strides.back();
int64_t stride_back = 1;
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back();
shape_vec.pop_back();
@@ -54,8 +60,8 @@ struct ColReduceArgs {
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right];
});
decltype(shape_vec) sorted_shape;
decltype(strides_vec) sorted_strides;
ShapeVector sorted_shape;
StridesVector sorted_strides;
for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]);
@@ -206,26 +212,25 @@ void col_reduce_looped(
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
encoder.set_input_array(x);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN);
int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
});
});
});
@@ -247,8 +252,7 @@ void col_reduce(
// a subrow of the fast moving axis. For instance 32 elements.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave
// transpositions as they are (contrary to our Metal backend).
// leave transpositions as they are (contrary to our Metal backend).
//
// Moreover we need different kernels for short rows and tuning

View File

@@ -31,7 +31,6 @@ void init_reduce(
out.set_data(allocator::malloc(out.nbytes()));
}
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {

View File

@@ -146,7 +146,7 @@ inline void allocate_same_layout(
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
fl.contiguous = true;
out.set_data(
allocator::malloc(out.nbytes()),
data_size,

View File

@@ -241,20 +241,19 @@ void row_reduce_simple(
// kernel.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(x);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
@@ -271,7 +270,7 @@ void row_reduce_simple(
// Launch
kernel<<<grid, block, 0, stream>>>(
x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
indata, out.data<U>(), out.size(), plan.shape.back());
});
});
});
@@ -291,20 +290,19 @@ void row_reduce_looped(
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
encoder.set_input_array(x);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
args.sort_access_pattern(x, axes);
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
@@ -322,7 +320,7 @@ void row_reduce_looped(
// Launch
kernel<<<grid, block, 0, stream>>>(
x.data<T>(), out.data<U>(), out.size(), args);
indata, out.data<U>(), out.size(), args);
});
});
});