More fixes for all reductions

This commit is contained in:
Angelos Katharopoulos 2025-06-24 01:17:30 -07:00
parent 8bd4bf2393
commit a57a75b992
6 changed files with 93 additions and 57 deletions

View File

@ -46,7 +46,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
broadcasted = in.strides(i) == 0; broadcasted = in.strides(i) == 0;
} }
} }
if (plan.type == GeneralReduce || broadcasted) { if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy(in.shape(), in.dtype(), nullptr, {}); array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s); copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy); encoder.add_temporary(in_copy);

View File

@ -33,18 +33,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
size_t end = start + block_step; size_t end = start + block_step;
size_t check = min(end, size); size_t check = min(end, size);
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) { size_t i = start;
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals); cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j])); accs[0] = op(accs[0], __cast<U, T>(vals[j]));
} }
} }
if (end > size) { if (i < check) {
size_t offset = end - block.size() * N;
int block_end = size - offset;
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init)); block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i])); accs[0] = op(accs[0], __cast<U, T>(vals[i]));
} }
@ -70,24 +69,27 @@ void all_reduce(
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto get_args = [](size_t size, int N) { auto get_args = [](size_t size, int N) {
size_t reductions = size / N; int threads = std::min(512UL, (size + N - 1) / N);
int threads = 512; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
size_t full_blocks = (reductions + threads - 1) / threads; int reductions_per_step = threads * N;
size_t steps_needed =
(size + reductions_per_step - 1) / reductions_per_step;
int blocks; int blocks;
if (full_blocks < 32) { if (steps_needed < 32) {
blocks = 1; blocks = 1;
} else if (full_blocks < 128) { } else if (steps_needed < 128) {
blocks = 32; blocks = 32;
} else if (full_blocks < 512) { } else if (steps_needed < 512) {
blocks = 128; blocks = 128;
} else if (full_blocks < 1024) { } else if (steps_needed < 1024) {
blocks = 512; blocks = 512;
} else { } else {
blocks = 1024; blocks = 1024;
} }
size_t reductions_per_block = std::max(
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks); size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
size_t block_step = reductions_per_block * N; size_t block_step = steps_per_block * reductions_per_step;
return std::make_tuple(blocks, threads, block_step); return std::make_tuple(blocks, threads, block_step);
}; };
@ -99,7 +101,6 @@ void all_reduce(
// Large array so allocate an intermediate and accumulate there // Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS); std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
if (blocks > 1) { if (blocks > 1) {
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
array intermediate({blocks}, out.dtype(), nullptr, {}); array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes())); intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate); encoder.add_temporary(intermediate);

View File

@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
@ -47,8 +49,19 @@ struct ColReduceArgs {
shape_vec.pop_back(); shape_vec.pop_back();
strides_vec.pop_back(); strides_vec.pop_back();
} }
std::vector<int> indices(shape_vec.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right];
});
decltype(shape_vec) sorted_shape;
decltype(strides_vec) sorted_strides;
for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]);
}
std::tie(shape_vec, strides_vec) = std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec); collapse_contiguous_dims(sorted_shape, sorted_strides);
shape = const_param(shape_vec); shape = const_param(shape_vec);
strides = const_param(strides_vec); strides = const_param(strides_vec);
ndim = shape_vec.size(); ndim = shape_vec.size();
@ -167,16 +180,18 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args) { const cu::ColReduceArgs& args,
Shape out_shape; int bn) {
Strides out_strides; int gx, gy = 1;
for (int i = 0; i < out.ndim(); i++) { size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
if (out.strides(i) >= args.reduction_stride) { size_t n_outer_blocks = out.size() / args.reduction_stride;
out_shape.push_back(out.shape(i)); size_t n_blocks = n_outer_blocks * n_inner_blocks;
out_strides.push_back(out.strides(i)); while (n_blocks / gy > INT32_MAX) {
} gy *= 2;
} }
return get_2d_grid_dims(out_shape, out_strides); gx = cuda::ceil_div(n_blocks, gy);
return dim3(gx, gy, 1);
} }
void col_reduce_looped( void col_reduce_looped(
@ -207,16 +222,7 @@ void col_reduce_looped(
constexpr int N_READS = 4; constexpr int N_READS = 4;
constexpr int BM = 32; constexpr int BM = 32;
constexpr int BN = 32; constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args); dim3 grid = output_grid_for_col_reduce(out, args, BN);
size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
if (grid.x * extra_blocks < INT32_MAX) {
grid.x *= extra_blocks;
} else if (grid.y * extra_blocks < 65536) {
grid.y *= extra_blocks;
} else {
throw std::runtime_error(
"[col_reduce_looped] Need to factorize reduction_stride");
}
int blocks = BM * BN / N_READS; int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>; auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args); kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);

View File

@ -2,6 +2,8 @@
#pragma once #pragma once
#include <numeric>
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@ -106,19 +108,31 @@ inline void allocate_same_layout(
array& out, array& out,
const array& in, const array& in,
const std::vector<int>& axes) { const std::vector<int>& axes) {
// Initialize out such that it matches in's layout. Basically we keep any // Calculate the transpositions applied to in in order to apply them to out.
// transpositions as it were and that allows us either to skip finding the std::vector<int> axis_order(in.ndim());
// location of the output that matches the input or simply contiguous read or std::iota(axis_order.begin(), axis_order.end(), 0);
// writes. std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
auto out_strides = in.strides(); return in.strides(left) > in.strides(right);
for (auto ax : axes) { });
for (auto& s : out_strides) {
if (s > in.strides(ax) && in.strides(ax) > 0) { // Transpose the shape and calculate the strides
s /= in.shape(ax); Shape out_shape(in.ndim());
} Strides out_strides(in.ndim(), 1);
} for (int i = 0; i < in.ndim(); i++) {
out_shape[i] = out.shape(axis_order[i]);
} }
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides); for (int i = in.ndim() - 2; i >= 0; i--) {
out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
}
// Reverse the axis order to get the final strides
Strides final_strides(in.ndim());
for (int i = 0; i < in.ndim(); i++) {
final_strides[axis_order[i]] = out_strides[i];
}
// Calculate the resulting contiguity and do the memory allocation
auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
auto fl = in.flags(); auto fl = in.flags();
fl.row_contiguous = rc; fl.row_contiguous = rc;
fl.col_contiguous = cc; fl.col_contiguous = cc;
@ -126,7 +140,7 @@ inline void allocate_same_layout(
out.set_data( out.set_data(
allocator::malloc(out.nbytes()), allocator::malloc(out.nbytes()),
data_size, data_size,
out_strides, final_strides,
fl, fl,
allocator::free); allocator::free);
} }

View File

@ -105,12 +105,28 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in += start_row * size; in += start_row * size;
out += start_row; out += start_row;
for (size_t r = 0; r < full_blocks; r++) { if (size % N == 0) {
for (int k = 0; k < M; k++) { for (size_t r = 0; r < full_blocks; r++) {
cub::LoadDirectBlockedVectorized<T, N>( for (int k = 0; k < M; k++) {
block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]); cub::LoadDirectBlockedVectorized<T, N>(
for (int j = 0; j < N; j++) { block.thread_rank(),
accs[k] = op(accs[k], __cast<U, T>(vals[k][j])); in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
} }
} }
} }

View File

@ -12,7 +12,6 @@ cuda_skip = {
"TestLayers.test_upsample", "TestLayers.test_upsample",
"TestOps.test_complex_ops", "TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing", "TestOps.test_dynamic_slicing",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes", "TestReduce.test_dtypes",
"TestUpsample.test_torch_upsample", "TestUpsample.test_torch_upsample",
# Block masked matmul NYI # Block masked matmul NYI