mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
More fixes for all reductions
This commit is contained in:
parent
8bd4bf2393
commit
a57a75b992
@ -46,7 +46,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
broadcasted = in.strides(i) == 0;
|
||||
}
|
||||
}
|
||||
if (plan.type == GeneralReduce || broadcasted) {
|
||||
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
|
@ -33,18 +33,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
|
||||
size_t end = start + block_step;
|
||||
size_t check = min(end, size);
|
||||
|
||||
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
|
||||
size_t i = start;
|
||||
for (; i + block.size() * N <= check; i += block.size() * N) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
|
||||
}
|
||||
}
|
||||
|
||||
if (end > size) {
|
||||
size_t offset = end - block.size() * N;
|
||||
int block_end = size - offset;
|
||||
if (i < check) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
|
||||
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
|
||||
}
|
||||
@ -70,24 +69,27 @@ void all_reduce(
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto get_args = [](size_t size, int N) {
|
||||
size_t reductions = size / N;
|
||||
int threads = 512;
|
||||
size_t full_blocks = (reductions + threads - 1) / threads;
|
||||
int threads = std::min(512UL, (size + N - 1) / N);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int reductions_per_step = threads * N;
|
||||
size_t steps_needed =
|
||||
(size + reductions_per_step - 1) / reductions_per_step;
|
||||
|
||||
int blocks;
|
||||
if (full_blocks < 32) {
|
||||
if (steps_needed < 32) {
|
||||
blocks = 1;
|
||||
} else if (full_blocks < 128) {
|
||||
} else if (steps_needed < 128) {
|
||||
blocks = 32;
|
||||
} else if (full_blocks < 512) {
|
||||
} else if (steps_needed < 512) {
|
||||
blocks = 128;
|
||||
} else if (full_blocks < 1024) {
|
||||
} else if (steps_needed < 1024) {
|
||||
blocks = 512;
|
||||
} else {
|
||||
blocks = 1024;
|
||||
}
|
||||
size_t reductions_per_block = std::max(
|
||||
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
|
||||
size_t block_step = reductions_per_block * N;
|
||||
|
||||
size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
|
||||
size_t block_step = steps_per_block * reductions_per_step;
|
||||
|
||||
return std::make_tuple(blocks, threads, block_step);
|
||||
};
|
||||
@ -99,7 +101,6 @@ void all_reduce(
|
||||
// Large array so allocate an intermediate and accumulate there
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
if (blocks > 1) {
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
encoder.add_temporary(intermediate);
|
||||
|
@ -1,5 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
@ -47,8 +49,19 @@ struct ColReduceArgs {
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::vector<int> indices(shape_vec.size());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||
return strides_vec[left] > strides_vec[right];
|
||||
});
|
||||
decltype(shape_vec) sorted_shape;
|
||||
decltype(strides_vec) sorted_strides;
|
||||
for (auto idx : indices) {
|
||||
sorted_shape.push_back(shape_vec[idx]);
|
||||
sorted_strides.push_back(strides_vec[idx]);
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
@ -167,16 +180,18 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
Shape out_shape;
|
||||
Strides out_strides;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
if (out.strides(i) >= args.reduction_stride) {
|
||||
out_shape.push_back(out.shape(i));
|
||||
out_strides.push_back(out.strides(i));
|
||||
const cu::ColReduceArgs& args,
|
||||
int bn) {
|
||||
int gx, gy = 1;
|
||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
||||
while (n_blocks / gy > INT32_MAX) {
|
||||
gy *= 2;
|
||||
}
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
gx = cuda::ceil_div(n_blocks, gy);
|
||||
|
||||
return dim3(gx, gy, 1);
|
||||
}
|
||||
|
||||
void col_reduce_looped(
|
||||
@ -207,16 +222,7 @@ void col_reduce_looped(
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args);
|
||||
size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
|
||||
if (grid.x * extra_blocks < INT32_MAX) {
|
||||
grid.x *= extra_blocks;
|
||||
} else if (grid.y * extra_blocks < 65536) {
|
||||
grid.y *= extra_blocks;
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[col_reduce_looped] Need to factorize reduction_stride");
|
||||
}
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
|
||||
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
@ -106,19 +108,31 @@ inline void allocate_same_layout(
|
||||
array& out,
|
||||
const array& in,
|
||||
const std::vector<int>& axes) {
|
||||
// Initialize out such that it matches in's layout. Basically we keep any
|
||||
// transpositions as it were and that allows us either to skip finding the
|
||||
// location of the output that matches the input or simply contiguous read or
|
||||
// writes.
|
||||
auto out_strides = in.strides();
|
||||
for (auto ax : axes) {
|
||||
for (auto& s : out_strides) {
|
||||
if (s > in.strides(ax) && in.strides(ax) > 0) {
|
||||
s /= in.shape(ax);
|
||||
// Calculate the transpositions applied to in in order to apply them to out.
|
||||
std::vector<int> axis_order(in.ndim());
|
||||
std::iota(axis_order.begin(), axis_order.end(), 0);
|
||||
std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
|
||||
return in.strides(left) > in.strides(right);
|
||||
});
|
||||
|
||||
// Transpose the shape and calculate the strides
|
||||
Shape out_shape(in.ndim());
|
||||
Strides out_strides(in.ndim(), 1);
|
||||
for (int i = 0; i < in.ndim(); i++) {
|
||||
out_shape[i] = out.shape(axis_order[i]);
|
||||
}
|
||||
for (int i = in.ndim() - 2; i >= 0; i--) {
|
||||
out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
|
||||
}
|
||||
|
||||
// Reverse the axis order to get the final strides
|
||||
Strides final_strides(in.ndim());
|
||||
for (int i = 0; i < in.ndim(); i++) {
|
||||
final_strides[axis_order[i]] = out_strides[i];
|
||||
}
|
||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
|
||||
// Calculate the resulting contiguity and do the memory allocation
|
||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
|
||||
auto fl = in.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
@ -126,7 +140,7 @@ inline void allocate_same_layout(
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
data_size,
|
||||
out_strides,
|
||||
final_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
}
|
||||
|
@ -105,15 +105,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
in += start_row * size;
|
||||
out += start_row;
|
||||
|
||||
if (size % N == 0) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(
|
||||
block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (final_offset < size) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
|
@ -12,7 +12,6 @@ cuda_skip = {
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
# Block masked matmul NYI
|
||||
|
Loading…
Reference in New Issue
Block a user