mirror of https://github.com/ml-explore/mlx.git
More fixes for all reductions
commit a57a75b992
parent 8bd4bf2393
```diff
@@ -46,7 +46,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       broadcasted = in.strides(i) == 0;
     }
   }
-  if (plan.type == GeneralReduce || broadcasted) {
+  if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
     array in_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, in_copy, CopyType::General, s);
     encoder.add_temporary(in_copy);
```
```diff
@@ -33,18 +33,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   size_t end = start + block_step;
   size_t check = min(end, size);
 
-  for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
+  size_t i = start;
+  for (; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
     for (int j = 0; j < N; j++) {
       accs[0] = op(accs[0], __cast<U, T>(vals[j]));
     }
   }
 
-  if (end > size) {
-    size_t offset = end - block.size() * N;
-    int block_end = size - offset;
+  if (i < check) {
     cub::LoadDirectBlocked(
-        block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
+        block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
     for (int i = 0; i < N; i++) {
       accs[0] = op(accs[0], __cast<U, T>(vals[i]));
     }
```
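
The key change in this hunk is that the loop index now survives the vectorized loop and drives the guarded tail load directly, instead of re-deriving an offset from `end`, which mis-addresses the tail when a block's range is shorter than one full step. A host-side analogue of that control flow, as a sketch only (the `W` and `N` constants below stand in for `block.size()` and the per-thread vector width and are not part of the kernel):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: accumulate the [start, start + block_step) slice of `in`, clipped to the
// array size. Full W * N chunks go through the unguarded loop; whatever remains is
// handled by a guarded tail loop driven by the same index i.
float reduce_slice(const std::vector<float>& in, size_t start, size_t block_step) {
  constexpr size_t W = 4; // stand-in for block.size()
  constexpr size_t N = 2; // stand-in for the per-thread vector width
  size_t end = start + block_step;
  size_t check = std::min(end, in.size());

  float acc = 0.0f;
  size_t i = start;
  for (; i + W * N <= check; i += W * N) { // unguarded, full steps only
    for (size_t j = 0; j < W * N; j++) {
      acc += in[i + j];
    }
  }
  for (; i < check; i++) { // guarded tail: exactly the elements left over
    acc += in[i];
  }
  return acc;
}

int main() {
  std::vector<float> x(37, 1.0f);
  float total = reduce_slice(x, 0, 16) + reduce_slice(x, 16, 16) + reduce_slice(x, 32, 16);
  return total == 37.0f ? 0 : 1; // expect 37 ones summed across three slices
}
```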
```diff
@@ -70,24 +69,27 @@ void all_reduce(
   out.set_data(allocator::malloc(out.nbytes()));
 
   auto get_args = [](size_t size, int N) {
-    size_t reductions = size / N;
-    int threads = 512;
-    size_t full_blocks = (reductions + threads - 1) / threads;
+    int threads = std::min(512UL, (size + N - 1) / N);
+    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+    int reductions_per_step = threads * N;
+    size_t steps_needed =
+        (size + reductions_per_step - 1) / reductions_per_step;
 
     int blocks;
-    if (full_blocks < 32) {
+    if (steps_needed < 32) {
       blocks = 1;
-    } else if (full_blocks < 128) {
+    } else if (steps_needed < 128) {
       blocks = 32;
-    } else if (full_blocks < 512) {
+    } else if (steps_needed < 512) {
       blocks = 128;
-    } else if (full_blocks < 1024) {
+    } else if (steps_needed < 1024) {
       blocks = 512;
     } else {
       blocks = 1024;
     }
-    size_t reductions_per_block = std::max(
-        static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
-    size_t block_step = reductions_per_block * N;
+    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
+    size_t block_step = steps_per_block * reductions_per_step;
 
     return std::make_tuple(blocks, threads, block_step);
   };
```
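
For reference, the new sizing logic reads naturally as plain host code. The sketch below reproduces the same arithmetic outside the lambda, assuming `WARP_SIZE` is 32; `block_step` is the number of input elements each block walks over, and the block-count thresholds are the ones from the diff:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <tuple>

// Sketch of the new get_args sizing heuristic as a standalone function.
// Assumption: WARP_SIZE = 32; each "step" covers threads * N elements.
std::tuple<int, int, size_t> get_args(size_t size, int N) {
  constexpr int WARP_SIZE = 32;
  int threads = std::min<size_t>(512, (size + N - 1) / N);
  threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
  int reductions_per_step = threads * N;
  size_t steps_needed = (size + reductions_per_step - 1) / reductions_per_step;

  int blocks;
  if (steps_needed < 32) {
    blocks = 1;
  } else if (steps_needed < 128) {
    blocks = 32;
  } else if (steps_needed < 512) {
    blocks = 128;
  } else if (steps_needed < 1024) {
    blocks = 512;
  } else {
    blocks = 1024;
  }
  size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
  size_t block_step = steps_per_block * reductions_per_step;
  return {blocks, threads, block_step};
}

int main() {
  // Example: a 1M-element reduction reading N = 8 elements per thread.
  auto [blocks, threads, block_step] = get_args(1 << 20, 8);
  std::printf("blocks=%d threads=%d block_step=%zu\n", blocks, threads, block_step);
}
```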
```diff
@@ -99,7 +101,6 @@ void all_reduce(
     // Large array so allocate an intermediate and accumulate there
     std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
     if (blocks > 1) {
-      std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
       array intermediate({blocks}, out.dtype(), nullptr, {});
       intermediate.set_data(allocator::malloc(intermediate.nbytes()));
       encoder.add_temporary(intermediate);
```
```diff
@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
+#include <numeric>
+
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
```
```diff
@@ -47,8 +49,19 @@ struct ColReduceArgs {
       shape_vec.pop_back();
       strides_vec.pop_back();
     }
+    std::vector<int> indices(shape_vec.size());
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+      return strides_vec[left] > strides_vec[right];
+    });
+    decltype(shape_vec) sorted_shape;
+    decltype(strides_vec) sorted_strides;
+    for (auto idx : indices) {
+      sorted_shape.push_back(shape_vec[idx]);
+      sorted_strides.push_back(strides_vec[idx]);
+    }
     std::tie(shape_vec, strides_vec) =
-        collapse_contiguous_dims(shape_vec, strides_vec);
+        collapse_contiguous_dims(sorted_shape, sorted_strides);
     shape = const_param(shape_vec);
     strides = const_param(strides_vec);
     ndim = shape_vec.size();
```
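
The added block is an argsort of the reduction dimensions by stride, largest first, before collapsing. A standalone sketch of that step, using plain vectors in place of MLX's shape/stride containers (an illustration, not the struct itself):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Sketch: reorder dims from largest to smallest stride so adjacent dims have a
// chance to collapse into one, even when the input is transposed.
void sort_dims_by_stride(std::vector<int32_t>& shape, std::vector<int64_t>& strides) {
  std::vector<int> indices(shape.size());
  std::iota(indices.begin(), indices.end(), 0);
  std::sort(indices.begin(), indices.end(), [&](int left, int right) {
    return strides[left] > strides[right];
  });

  std::vector<int32_t> sorted_shape;
  std::vector<int64_t> sorted_strides;
  for (auto idx : indices) {
    sorted_shape.push_back(shape[idx]);
    sorted_strides.push_back(strides[idx]);
  }
  shape = std::move(sorted_shape);
  strides = std::move(sorted_strides);
}

int main() {
  // A transposed layout: dimension 0 has the smallest stride.
  std::vector<int32_t> shape = {4, 8, 16};
  std::vector<int64_t> strides = {1, 64, 4};
  sort_dims_by_stride(shape, strides);
  // shape is now {8, 16, 4}, strides {64, 4, 1}: ready for contiguous collapsing.
  return (strides[0] == 64 && strides[1] == 4 && strides[2] == 1) ? 0 : 1;
}
```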
```diff
@@ -167,16 +180,18 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
 
 inline auto output_grid_for_col_reduce(
     const array& out,
-    const cu::ColReduceArgs& args) {
-  Shape out_shape;
-  Strides out_strides;
-  for (int i = 0; i < out.ndim(); i++) {
-    if (out.strides(i) >= args.reduction_stride) {
-      out_shape.push_back(out.shape(i));
-      out_strides.push_back(out.strides(i));
-    }
-  }
-  return get_2d_grid_dims(out_shape, out_strides);
+    const cu::ColReduceArgs& args,
+    int bn) {
+  int gx, gy = 1;
+  size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
+  size_t n_outer_blocks = out.size() / args.reduction_stride;
+  size_t n_blocks = n_outer_blocks * n_inner_blocks;
+  while (n_blocks / gy > INT32_MAX) {
+    gy *= 2;
+  }
+  gx = cuda::ceil_div(n_blocks, gy);
+
+  return dim3(gx, gy, 1);
 }
 
 void col_reduce_looped(
```
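
Instead of deriving a 2D grid from the output layout and then multiplying in the column tiles at the launch site (the logic removed in the next hunk), the grid is now computed directly from the total block count. A host-side sketch of the new computation with plain integers standing in for `dim3` and a local helper in place of `cuda::ceil_div` (example sizes are illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch: total blocks = (output rows) * (BN-wide column tiles); spill into grid.y
// in powers of two only when grid.x alone would overflow INT32_MAX.
struct Grid {
  int64_t x, y, z;
};

constexpr int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

Grid output_grid_for_col_reduce(int64_t out_size, int64_t reduction_stride, int bn) {
  int64_t gy = 1;
  int64_t n_inner_blocks = ceil_div(reduction_stride, bn); // column tiles per row
  int64_t n_outer_blocks = out_size / reduction_stride;    // output "rows"
  int64_t n_blocks = n_outer_blocks * n_inner_blocks;
  while (n_blocks / gy > INT32_MAX) {
    gy *= 2;
  }
  int64_t gx = ceil_div(n_blocks, gy);
  return {gx, gy, 1};
}

int main() {
  Grid g = output_grid_for_col_reduce(int64_t(1) << 20, 4096, 32);
  std::printf("grid = (%lld, %lld, %lld)\n",
              (long long)g.x, (long long)g.y, (long long)g.z);
}
```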
```diff
@@ -207,16 +222,7 @@ void col_reduce_looped(
     constexpr int N_READS = 4;
     constexpr int BM = 32;
     constexpr int BN = 32;
-    dim3 grid = output_grid_for_col_reduce(out, args);
-    size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
-    if (grid.x * extra_blocks < INT32_MAX) {
-      grid.x *= extra_blocks;
-    } else if (grid.y * extra_blocks < 65536) {
-      grid.y *= extra_blocks;
-    } else {
-      throw std::runtime_error(
-          "[col_reduce_looped] Need to factorize reduction_stride");
-    }
+    dim3 grid = output_grid_for_col_reduce(out, args, BN);
     int blocks = BM * BN / N_READS;
     auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
     kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
```
```diff
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include <numeric>
+
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <cooperative_groups.h>
```
```diff
@@ -106,19 +108,31 @@ inline void allocate_same_layout(
     array& out,
     const array& in,
     const std::vector<int>& axes) {
-  // Initialize out such that it matches in's layout. Basically we keep any
-  // transpositions as it were and that allows us either to skip finding the
-  // location of the output that matches the input or simply contiguous read or
-  // writes.
-  auto out_strides = in.strides();
-  for (auto ax : axes) {
-    for (auto& s : out_strides) {
-      if (s > in.strides(ax) && in.strides(ax) > 0) {
-        s /= in.shape(ax);
-      }
-    }
-  }
-  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
+  // Calculate the transpositions applied to in in order to apply them to out.
+  std::vector<int> axis_order(in.ndim());
+  std::iota(axis_order.begin(), axis_order.end(), 0);
+  std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
+    return in.strides(left) > in.strides(right);
+  });
+
+  // Transpose the shape and calculate the strides
+  Shape out_shape(in.ndim());
+  Strides out_strides(in.ndim(), 1);
+  for (int i = 0; i < in.ndim(); i++) {
+    out_shape[i] = out.shape(axis_order[i]);
+  }
+  for (int i = in.ndim() - 2; i >= 0; i--) {
+    out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
+  }
+
+  // Reverse the axis order to get the final strides
+  Strides final_strides(in.ndim());
+  for (int i = 0; i < in.ndim(); i++) {
+    final_strides[axis_order[i]] = out_strides[i];
+  }
+
+  // Calculate the resulting contiguity and do the memory allocation
+  auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
   auto fl = in.flags();
   fl.row_contiguous = rc;
   fl.col_contiguous = cc;
```
```diff
@@ -126,7 +140,7 @@ inline void allocate_same_layout(
   out.set_data(
       allocator::malloc(out.nbytes()),
       data_size,
-      out_strides,
+      final_strides,
      fl,
      allocator::free);
 }
```
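
The rewritten `allocate_same_layout` derives the output strides by permuting the axes into descending-stride order, laying the output out contiguously in that order, and scattering the strides back to the original axis positions. A small worked sketch of that stride computation, with plain vectors in place of MLX's `Shape`/`Strides` (the example shapes are illustrative):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Sketch: compute output strides that mirror the input's transposition.
std::vector<int64_t> strides_like(
    const std::vector<int32_t>& out_shape,
    const std::vector<int64_t>& in_strides) {
  int ndim = out_shape.size();
  std::vector<int> axis_order(ndim);
  std::iota(axis_order.begin(), axis_order.end(), 0);
  std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
    return in_strides[left] > in_strides[right];
  });

  // Shape in descending-stride order, then contiguous strides for that order.
  std::vector<int32_t> t_shape(ndim);
  std::vector<int64_t> t_strides(ndim, 1);
  for (int i = 0; i < ndim; i++) {
    t_shape[i] = out_shape[axis_order[i]];
  }
  for (int i = ndim - 2; i >= 0; i--) {
    t_strides[i] = t_shape[i + 1] * t_strides[i + 1];
  }

  // Undo the permutation to get strides for the original axis order.
  std::vector<int64_t> final_strides(ndim);
  for (int i = 0; i < ndim; i++) {
    final_strides[axis_order[i]] = t_strides[i];
  }
  return final_strides;
}

int main() {
  // Input is a transposed (column-major) 8x16 matrix with strides {1, 8};
  // the 8x16 output should be laid out the same way: strides {1, 8}.
  auto s = strides_like({8, 16}, {1, 8});
  return (s[0] == 1 && s[1] == 8) ? 0 : 1;
}
```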
```diff
@@ -105,12 +105,28 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   in += start_row * size;
   out += start_row;
 
-  for (size_t r = 0; r < full_blocks; r++) {
-    for (int k = 0; k < M; k++) {
-      cub::LoadDirectBlockedVectorized<T, N>(
-          block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
-      for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+  if (size % N == 0) {
+    for (size_t r = 0; r < full_blocks; r++) {
+      for (int k = 0; k < M; k++) {
+        cub::LoadDirectBlockedVectorized<T, N>(
+            block.thread_rank(),
+            in + k * size + r * (block.size() * N),
+            vals[k]);
+        for (int j = 0; j < N; j++) {
+          accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        }
+      }
+    }
+  } else {
+    for (size_t r = 0; r < full_blocks; r++) {
+      for (int k = 0; k < M; k++) {
+        cub::LoadDirectBlocked(
+            block.thread_rank(),
+            in + k * size + r * (block.size() * N),
+            vals[k]);
+        for (int j = 0; j < N; j++) {
+          accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        }
       }
     }
   }
```
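
The simple row reduction now takes the vectorized-load path only when the row length is divisible by the vector width `N`, and falls back to guarded per-element loads otherwise. A scalar CPU analogue of that branch, as a sketch (ordinary loops stand in for the `cub` block loads):

```cpp
#include <cstddef>
#include <vector>

// Sketch: sum one row, using N-wide chunks only when the row length divides evenly,
// so a ragged row never reads past its end.
template <int N>
float sum_row(const float* row, size_t size) {
  float acc = 0.0f;
  if (size % N == 0) {
    // "Vectorized" path: the whole row is covered by N-wide chunks.
    for (size_t i = 0; i < size; i += N) {
      for (int j = 0; j < N; j++) {
        acc += row[i + j];
      }
    }
  } else {
    // Guarded path: plain element-by-element loop.
    for (size_t i = 0; i < size; i++) {
      acc += row[i];
    }
  }
  return acc;
}

int main() {
  std::vector<float> row(37, 1.0f);
  return sum_row<4>(row.data(), row.size()) == 37.0f ? 0 : 1;
}
```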
```diff
@@ -12,7 +12,6 @@ cuda_skip = {
    "TestLayers.test_upsample",
    "TestOps.test_complex_ops",
    "TestOps.test_dynamic_slicing",
-   "TestReduce.test_axis_permutation_sums",
    "TestReduce.test_dtypes",
    "TestUpsample.test_torch_upsample",
    # Block masked matmul NYI
```