This commit is contained in:
Angelos Katharopoulos
2025-06-26 22:54:09 -07:00
parent d999675cb9
commit bc60a31cae
5 changed files with 49 additions and 39 deletions

View File

@@ -96,43 +96,52 @@ void all_reduce(
int blocks, threads; int blocks, threads;
size_t block_step; size_t block_step;
array x = in; size_t insize = in.size();
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
// Large array so allocate an intermediate and accumulate there // Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS); std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) { if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {}); array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes())); intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate); encoder.add_temporary(intermediate);
encoder.set_input_array(x);
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type; using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>( kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), intermediate.data<U>(), block_step, x.size()); static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
insize);
}); });
}); });
}); });
// Set the input for the next step and recalculate the blocks // Set the input for the next step and recalculate the blocks
x = intermediate; indata = intermediate.data<void>();
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS); dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(intermediate);
} }
encoder.set_input_array(x);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type; using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>( kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), out.data<U>(), block_step, x.size()); static_cast<T*>(indata), out.data<U>(), block_step, insize);
}); });
}); });
}); });

View File

@@ -38,12 +38,18 @@ struct ColReduceArgs {
const array& in, const array& in,
const ReductionPlan& plan, const ReductionPlan& plan,
const std::vector<int>& axes) { const std::vector<int>& axes) {
using ShapeVector = decltype(plan.shape);
using StridesVector = decltype(plan.strides);
ShapeVector shape_vec;
StridesVector strides_vec;
assert(!plan.shape.empty()); assert(!plan.shape.empty());
reduction_size = plan.shape.back(); reduction_size = plan.shape.back();
reduction_stride = plan.strides.back(); reduction_stride = plan.strides.back();
int64_t stride_back = 1; int64_t stride_back = 1;
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) { while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back(); stride_back *= shape_vec.back();
shape_vec.pop_back(); shape_vec.pop_back();
@@ -54,8 +60,8 @@ struct ColReduceArgs {
std::sort(indices.begin(), indices.end(), [&](int left, int right) { std::sort(indices.begin(), indices.end(), [&](int left, int right) {
return strides_vec[left] > strides_vec[right]; return strides_vec[left] > strides_vec[right];
}); });
decltype(shape_vec) sorted_shape; ShapeVector sorted_shape;
decltype(strides_vec) sorted_strides; StridesVector sorted_strides;
for (auto idx : indices) { for (auto idx : indices) {
sorted_shape.push_back(shape_vec[idx]); sorted_shape.push_back(shape_vec[idx]);
sorted_strides.push_back(strides_vec[idx]); sorted_strides.push_back(strides_vec[idx]);
@@ -206,26 +212,25 @@ void col_reduce_looped(
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ... encoder.set_input_array(in);
// (sigh)
array x = in;
encoder.set_input_array(x);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type; using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 4; constexpr int N_READS = 4;
constexpr int BM = 32; constexpr int BM = 32;
constexpr int BN = 32; constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN); dim3 grid = output_grid_for_col_reduce(out, args, BN);
int blocks = BM * BN / N_READS; int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>; auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args); kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
}); });
}); });
}); });
@@ -247,8 +252,7 @@ void col_reduce(
// a subrow of the fast moving axis. For instance 32 elements. // a subrow of the fast moving axis. For instance 32 elements.
// //
// Notes: As in row reduce we opt to read as much in order as possible and // Notes: As in row reduce we opt to read as much in order as possible and
// leave // leave transpositions as they are (contrary to our Metal backend).
// transpositions as they are (contrary to our Metal backend).
// //
// Moreover we need different kernels for short rows and tuning // Moreover we need different kernels for short rows and tuning

View File

@@ -31,7 +31,6 @@ void init_reduce(
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {

View File

@@ -146,7 +146,7 @@ inline void allocate_same_layout(
auto fl = in.flags(); auto fl = in.flags();
fl.row_contiguous = rc; fl.row_contiguous = rc;
fl.col_contiguous = cc; fl.col_contiguous = cc;
fl.contiguous = data_size == out.size(); fl.contiguous = true;
out.set_data( out.set_data(
allocator::malloc(out.nbytes()), allocator::malloc(out.nbytes()),
data_size, data_size,

View File

@@ -241,20 +241,19 @@ void row_reduce_simple(
// kernel. // kernel.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
// TODO: If out.size() < 1024 which will be a common case then write this in // TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce. // 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(x); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type; using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
@@ -271,7 +270,7 @@ void row_reduce_simple(
// Launch // Launch
kernel<<<grid, block, 0, stream>>>( kernel<<<grid, block, 0, stream>>>(
x.data<T>(), out.data<U>(), out.size(), plan.shape.back()); indata, out.data<U>(), out.size(), plan.shape.back());
}); });
}); });
}); });
@@ -291,20 +290,19 @@ void row_reduce_looped(
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ... encoder.set_input_array(in);
// (sigh)
array x = in;
encoder.set_input_array(x);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type; using U = cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
args.sort_access_pattern(x, axes); args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS; size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions); int threads = std::min(1024UL, reductions);
@@ -322,7 +320,7 @@ void row_reduce_looped(
// Launch // Launch
kernel<<<grid, block, 0, stream>>>( kernel<<<grid, block, 0, stream>>>(
x.data<T>(), out.data<U>(), out.size(), args); indata, out.data<U>(), out.size(), args);
}); });
}); });
}); });