More fixes for all reductions

Angelos Katharopoulos 2025-06-24 01:17:30 -07:00
parent 8bd4bf2393
commit a57a75b992
6 changed files with 93 additions and 57 deletions

View File

@@ -46,7 +46,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
broadcasted = in.strides(i) == 0;
}
}
-if (plan.type == GeneralReduce || broadcasted) {
+if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy);
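
Aside: the new condition forces a defensive copy for any input whose flags say it is not contiguous, not only for GeneralReduce plans or broadcast (stride-0) inputs. A minimal host-only C++ sketch of that predicate in isolation; ReductionKind, Flags and needs_copy are illustrative stand-ins, not the MLX types:

    // Sketch of the fallback predicate: copy the input into a fresh contiguous
    // buffer before using the specialized reduce kernels whenever any of these
    // conditions hold. Types here are stand-ins, not the MLX ones.
    #include <cstdio>

    enum class ReductionKind { ContiguousReduce, GeneralReduce };

    struct Flags {
      bool contiguous;  // set when the data occupies one dense block
    };

    bool needs_copy(ReductionKind kind, bool broadcasted, Flags in_flags) {
      return kind == ReductionKind::GeneralReduce || broadcasted || !in_flags.contiguous;
    }

    int main() {
      // A sliced, non-contiguous input now also takes the copy path.
      std::printf("%d\n", needs_copy(ReductionKind::ContiguousReduce, false, Flags{false}));
    }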

View File

@@ -33,18 +33,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
size_t end = start + block_step;
size_t check = min(end, size);
-for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
+size_t i = start;
+for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
}
}
-if (end > size) {
-size_t offset = end - block.size() * N;
-int block_end = size - offset;
+if (i < check) {
cub::LoadDirectBlocked(
-block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
+block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
}
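
Aside: the kernel now keeps the running index i from the vectorized loop and finishes the remainder from there, instead of re-deriving a tail window from end, which could start before the block's own range when block_step is smaller than one full vector step. A small host-side C++ sketch of the same control flow; plain loops stand in for the cub block loads, and reduce_range, block_size and N are illustrative:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    float reduce_range(const std::vector<float>& in, size_t start, size_t block_step,
                       size_t block_size, size_t N) {
      size_t size = in.size();
      size_t end = start + block_step;
      size_t check = std::min(end, size);
      float acc = 0.0f;
      size_t i = start;
      // "Vectorized" full steps: each consumes block_size * N elements.
      for (; i + block_size * N <= check; i += block_size * N) {
        for (size_t j = 0; j < block_size * N; j++) {
          acc += in[i + j];
        }
      }
      // Tail: finish whatever is left in [i, check) from the same index,
      // so nothing is skipped or read twice.
      for (; i < check; i++) {
        acc += in[i];
      }
      return acc;
    }

    int main() {
      std::vector<float> x(1000, 1.0f);
      // 7 full steps of 32 * 4 = 128 elements, then a 104-element tail.
      std::printf("%g\n", reduce_range(x, 0, 1000, 32, 4));  // prints 1000
    }
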
@@ -70,24 +69,27 @@ void all_reduce(
out.set_data(allocator::malloc(out.nbytes()));
auto get_args = [](size_t size, int N) {
-size_t reductions = size / N;
-int threads = 512;
-size_t full_blocks = (reductions + threads - 1) / threads;
+int threads = std::min(512UL, (size + N - 1) / N);
+threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+int reductions_per_step = threads * N;
+size_t steps_needed =
+(size + reductions_per_step - 1) / reductions_per_step;
int blocks;
-if (full_blocks < 32) {
+if (steps_needed < 32) {
blocks = 1;
-} else if (full_blocks < 128) {
+} else if (steps_needed < 128) {
blocks = 32;
-} else if (full_blocks < 512) {
+} else if (steps_needed < 512) {
blocks = 128;
-} else if (full_blocks < 1024) {
+} else if (steps_needed < 1024) {
blocks = 512;
} else {
blocks = 1024;
}
-size_t reductions_per_block = std::max(
-static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
-size_t block_step = reductions_per_block * N;
+size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
+size_t block_step = steps_per_block * reductions_per_step;
return std::make_tuple(blocks, threads, block_step);
};
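
Aside: to make the new launch-geometry arithmetic easy to check, here is a host-only C++ sketch of the same computation; get_args_sketch, WARP_SIZE = 32 and the example numbers are illustrative, not part of the commit. The block count is now driven by whole block passes (steps) rather than by reductions / threads, and block_step is always a multiple of threads * N.

    #include <algorithm>
    #include <cstdio>
    #include <tuple>

    std::tuple<int, int, size_t> get_args_sketch(size_t size, int N) {
      constexpr int WARP_SIZE = 32;  // assumption: 32-thread warps
      // At most 512 threads, never more than one thread per N elements,
      // rounded up to a whole number of warps.
      int threads = static_cast<int>(std::min<size_t>(512, (size + N - 1) / N));
      threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
      // One "step" is one pass of the whole block, each thread reading N values.
      int reductions_per_step = threads * N;
      size_t steps_needed = (size + reductions_per_step - 1) / reductions_per_step;
      int blocks;
      if (steps_needed < 32) {
        blocks = 1;
      } else if (steps_needed < 128) {
        blocks = 32;
      } else if (steps_needed < 512) {
        blocks = 128;
      } else if (steps_needed < 1024) {
        blocks = 512;
      } else {
        blocks = 1024;
      }
      // Each block covers a contiguous range of whole steps.
      size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
      size_t block_step = steps_per_block * reductions_per_step;
      return {blocks, threads, block_step};
    }

    int main() {
      // 1M elements, N = 4: 512 threads, 512 steps, 512 blocks, block_step = 2048.
      auto [blocks, threads, block_step] = get_args_sketch(size_t(1) << 20, 4);
      std::printf("blocks=%d threads=%d block_step=%zu\n", blocks, threads, block_step);
    }
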
@@ -99,7 +101,6 @@ void all_reduce(
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
if (blocks > 1) {
-std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate);
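
Aside: for context, this branch implements a two-pass reduction for large inputs: pass one writes one partial per block into the temporary intermediate array, pass two reduces those partials. A plain C++ sketch of that strategy, with host loops standing in for the two kernel launches; two_pass_sum and the sizes are illustrative:

    #include <algorithm>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    float two_pass_sum(const std::vector<float>& x, int blocks) {
      size_t block_step = (x.size() + blocks - 1) / blocks;
      std::vector<float> partials(blocks, 0.0f);
      // Pass one: each "block" reduces its own contiguous range into one partial.
      for (int b = 0; b < blocks; b++) {
        size_t start = b * block_step;
        size_t end = std::min(start + block_step, x.size());
        for (size_t i = start; i < end; i++) {
          partials[b] += x[i];
        }
      }
      // Pass two: reduce the per-block partials to the final value.
      return std::accumulate(partials.begin(), partials.end(), 0.0f);
    }

    int main() {
      std::vector<float> x(10000, 0.5f);
      std::printf("%g\n", two_pass_sum(x, 32));  // prints 5000
    }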

View File

@@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc.
+#include <numeric>
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -47,8 +49,19 @@ struct ColReduceArgs {
shape_vec.pop_back();
strides_vec.pop_back();
}
+std::vector<int> indices(shape_vec.size());
+std::iota(indices.begin(), indices.end(), 0);
+std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+return strides_vec[left] > strides_vec[right];
+});
+decltype(shape_vec) sorted_shape;
+decltype(strides_vec) sorted_strides;
+for (auto idx : indices) {
+sorted_shape.push_back(shape_vec[idx]);
+sorted_strides.push_back(strides_vec[idx]);
+}
std::tie(shape_vec, strides_vec) =
-collapse_contiguous_dims(shape_vec, strides_vec);
+collapse_contiguous_dims(sorted_shape, sorted_strides);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
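
Aside: the added lines sort the reduced dimensions by decreasing stride before collapse_contiguous_dims, so a transposed layout still collapses into as few loops as possible. A self-contained C++ sketch of the idea; collapse_contiguous_dims below is a simplified stand-in for the MLX helper, and the shape/stride numbers are made up. Without the sort, the (3, 4)/(1, 3) example stays two-dimensional; after sorting it collapses to a single contiguous run of 12 elements.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    std::pair<Shape, Strides> collapse_contiguous_dims(const Shape& shape, const Strides& strides) {
      Shape out_shape;
      Strides out_strides;
      for (size_t i = 0; i < shape.size(); i++) {
        // Merge dim i into the previous one when they describe one contiguous run.
        if (!out_shape.empty() && out_strides.back() == strides[i] * shape[i]) {
          out_shape.back() *= shape[i];
          out_strides.back() = strides[i];
        } else {
          out_shape.push_back(shape[i]);
          out_strides.push_back(strides[i]);
        }
      }
      return {out_shape, out_strides};
    }

    int main() {
      // A transposed view: shape (3, 4) with strides (1, 3).
      Shape shape{3, 4};
      Strides strides{1, 3};
      std::vector<int> idx(shape.size());
      std::iota(idx.begin(), idx.end(), 0);
      std::sort(idx.begin(), idx.end(), [&](int l, int r) { return strides[l] > strides[r]; });
      Shape sorted_shape;
      Strides sorted_strides;
      for (int i : idx) {
        sorted_shape.push_back(shape[i]);
        sorted_strides.push_back(strides[i]);
      }
      // After sorting: shape (4, 3), strides (3, 1) -> collapses to shape (12), stride (1).
      auto [cs, cstr] = collapse_contiguous_dims(sorted_shape, sorted_strides);
      std::printf("%d dims, first shape %d stride %lld\n", (int)cs.size(), cs[0], (long long)cstr[0]);
    }
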
@@ -167,16 +180,18 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
inline auto output_grid_for_col_reduce(
const array& out,
-const cu::ColReduceArgs& args) {
-Shape out_shape;
-Strides out_strides;
-for (int i = 0; i < out.ndim(); i++) {
-if (out.strides(i) >= args.reduction_stride) {
-out_shape.push_back(out.shape(i));
-out_strides.push_back(out.strides(i));
-}
-}
-return get_2d_grid_dims(out_shape, out_strides);
+const cu::ColReduceArgs& args,
+int bn) {
+int gx, gy = 1;
+size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
+size_t n_outer_blocks = out.size() / args.reduction_stride;
+size_t n_blocks = n_outer_blocks * n_inner_blocks;
+while (n_blocks / gy > INT32_MAX) {
+gy *= 2;
+}
+gx = cuda::ceil_div(n_blocks, gy);
+return dim3(gx, gy, 1);
}
void col_reduce_looped(
@@ -207,16 +222,7 @@ void col_reduce_looped(
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
-dim3 grid = output_grid_for_col_reduce(out, args);
-size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
-if (grid.x * extra_blocks < INT32_MAX) {
-grid.x *= extra_blocks;
-} else if (grid.y * extra_blocks < 65536) {
-grid.y *= extra_blocks;
-} else {
-throw std::runtime_error(
-"[col_reduce_looped] Need to factorize reduction_stride");
-}
+dim3 grid = output_grid_for_col_reduce(out, args, BN);
int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
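
Aside: the helper now derives the grid directly from the number of output rows and the number of BN-wide column tiles, splitting it over (x, y) so the x extent never exceeds INT32_MAX; this replaces the old path that could throw when reduction_stride produced too many blocks. A host-side C++ sketch of the same arithmetic; Dim3 and ceil_div are stand-ins for dim3 and cuda::ceil_div, and the example sizes are invented:

    #include <cstdint>
    #include <cstdio>

    struct Dim3 { unsigned x, y, z; };

    inline size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

    Dim3 grid_for_col_reduce(size_t out_size, size_t reduction_stride, int bn) {
      size_t n_inner_blocks = ceil_div(reduction_stride, bn);  // column tiles per row
      size_t n_outer_blocks = out_size / reduction_stride;     // independent output rows
      size_t n_blocks = n_outer_blocks * n_inner_blocks;
      size_t gy = 1;
      while (n_blocks / gy > INT32_MAX) {  // keep gx in signed 32-bit range
        gy *= 2;
      }
      size_t gx = ceil_div(n_blocks, gy);
      return {static_cast<unsigned>(gx), static_cast<unsigned>(gy), 1};
    }

    int main() {
      // 2^20 outputs with stride 4096 and BN = 32 -> 256 rows x 128 tiles = 32768 blocks.
      Dim3 g = grid_for_col_reduce(size_t(1) << 20, 4096, 32);
      std::printf("gx=%u gy=%u\n", g.x, g.y);  // gx=32768 gy=1
    }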

View File

@@ -2,6 +2,8 @@
#pragma once
+#include <numeric>
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
@@ -106,19 +108,31 @@ inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
// Initialize out such that it matches in's layout. Basically we keep any
// transpositions as it were and that allows us either to skip finding the
// location of the output that matches the input or simply contiguous read or
// writes.
-auto out_strides = in.strides();
-for (auto ax : axes) {
-for (auto& s : out_strides) {
-if (s > in.strides(ax) && in.strides(ax) > 0) {
-s /= in.shape(ax);
-}
-}
-}
-auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
+// Calculate the transpositions applied to in in order to apply them to out.
+std::vector<int> axis_order(in.ndim());
+std::iota(axis_order.begin(), axis_order.end(), 0);
+std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
+return in.strides(left) > in.strides(right);
+});
+// Transpose the shape and calculate the strides
+Shape out_shape(in.ndim());
+Strides out_strides(in.ndim(), 1);
+for (int i = 0; i < in.ndim(); i++) {
+out_shape[i] = out.shape(axis_order[i]);
+}
+for (int i = in.ndim() - 2; i >= 0; i--) {
+out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
+}
+// Reverse the axis order to get the final strides
+Strides final_strides(in.ndim());
+for (int i = 0; i < in.ndim(); i++) {
+final_strides[axis_order[i]] = out_strides[i];
+}
+// Calculate the resulting contiguity and do the memory allocation
+auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
@@ -126,7 +140,7 @@ inline void allocate_same_layout(
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
-out_strides,
+final_strides,
fl,
allocator::free);
}
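
Aside: the rewritten allocate_same_layout derives the output strides by ordering the axes by decreasing input stride, laying the output out contiguously in that order, and scattering the strides back, so a transposed input yields an output with the same axis order. A standalone C++ sketch with a worked example; same_layout_strides and the shapes/strides below are illustrative, with plain std::vector stand-ins for Shape/Strides:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> same_layout_strides(const std::vector<int>& out_shape,
                                             const std::vector<int64_t>& in_strides) {
      int ndim = static_cast<int>(out_shape.size());
      std::vector<int> axis_order(ndim);
      std::iota(axis_order.begin(), axis_order.end(), 0);
      std::sort(axis_order.begin(), axis_order.end(),
                [&](int l, int r) { return in_strides[l] > in_strides[r]; });
      // Shape permuted into the input's memory order, then contiguous strides.
      std::vector<int> perm_shape(ndim);
      std::vector<int64_t> perm_strides(ndim, 1);
      for (int i = 0; i < ndim; i++) perm_shape[i] = out_shape[axis_order[i]];
      for (int i = ndim - 2; i >= 0; i--) perm_strides[i] = perm_shape[i + 1] * perm_strides[i + 1];
      // Undo the permutation so strides line up with the original axes.
      std::vector<int64_t> final_strides(ndim);
      for (int i = 0; i < ndim; i++) final_strides[axis_order[i]] = perm_strides[i];
      return final_strides;
    }

    int main() {
      // Column-major input of shape (2, 3, 4): strides (1, 2, 6). Reducing axis 1
      // gives out_shape (2, 1, 4); the output keeps the column-major order,
      // so the expected strides are (1, 2, 2).
      auto s = same_layout_strides({2, 1, 4}, {1, 2, 6});
      std::printf("%lld %lld %lld\n", (long long)s[0], (long long)s[1], (long long)s[2]);
    }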

View File

@@ -105,15 +105,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in += start_row * size;
out += start_row;
+if (size % N == 0) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
-block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
+block.thread_rank(),
+in + k * size + r * (block.size() * N),
+vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
+} else {
+for (size_t r = 0; r < full_blocks; r++) {
+for (int k = 0; k < M; k++) {
+cub::LoadDirectBlocked(
+block.thread_rank(),
+in + k * size + r * (block.size() * N),
+vals[k]);
+for (int j = 0; j < N; j++) {
+accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+}
+}
+}
+}
if (final_offset < size) {
for (int k = 0; k < M; k++) {
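
Aside: the kernel now takes the vectorized cub load only when the row length is an exact multiple of N, and otherwise falls back to guarded element-wise loads. A small host-side C++ sketch of that dispatch; plain loops stand in for LoadDirectBlockedVectorized and LoadDirectBlocked, and sum_row is illustrative:

    #include <cstdio>
    #include <vector>

    float sum_row(const std::vector<float>& row, int N) {
      int size = static_cast<int>(row.size());
      float acc = 0.0f;
      if (size % N == 0) {
        // Vectorized path: every step consumes exactly N values, no bounds checks.
        for (int i = 0; i < size; i += N) {
          for (int j = 0; j < N; j++) acc += row[i + j];
        }
      } else {
        // Fallback path: bounds-checked loads, one element at a time.
        for (int i = 0; i < size; i++) acc += row[i];
      }
      return acc;
    }

    int main() {
      std::vector<float> a(128, 1.0f), b(130, 1.0f);
      std::printf("%g %g\n", sum_row(a, 4), sum_row(b, 4));  // prints 128 130
    }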

View File

@@ -12,7 +12,6 @@ cuda_skip = {
"TestLayers.test_upsample",
"TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestUpsample.test_torch_upsample",
# Block masked matmul NYI