Fixes for transpositions and expands

Angelos Katharopoulos 2025-06-23 05:49:49 -07:00
parent fd1d0821d2
commit 8bd4bf2393
5 changed files with 54 additions and 26 deletions

View File

@@ -34,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // If it is a general reduce then copy the input to a contiguous array and
   // recompute the plan.
-  if (plan.type == GeneralReduce) {
+  //
+  // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
+  //       like we do in Metal. When it comes to broadcasted reduction axes
+  //       some can be ignored eg for min/max.
+  bool broadcasted = false;
+  for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
+    if (j < axes_.size() && axes_[j] == i) {
+      j++;
+    } else {
+      broadcasted = in.strides(i) == 0;
+    }
+  }
+  if (plan.type == GeneralReduce || broadcasted) {
     array in_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, in_copy, CopyType::General, s);
     encoder.add_temporary(in_copy);
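
A note on the new guard: an axis produced by broadcasting carries stride 0, and if such an axis survives the reduction the kernels would revisit the same elements rather than distinct ones, so the input is materialized first. Below is a minimal standalone sketch of the detection logic; the helper name and the toy layout are illustrative, not code from the commit.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical standalone version of the detection above: an axis is
// "broadcasted" when it is not one of the (sorted) reduction axes and
// its stride is 0, i.e. the data is virtually repeated along it.
bool has_broadcasted_non_reduction_axis(
    const std::vector<int64_t>& strides,
    const std::vector<int>& sorted_axes) {
  for (int i = 0, j = 0; i < static_cast<int>(strides.size()); i++) {
    if (j < static_cast<int>(sorted_axes.size()) && sorted_axes[j] == i) {
      j++; // reduction axis: handled by the reduce plan itself
    } else if (strides[i] == 0) {
      return true; // broadcasted non-reduction axis: fall back to a copy
    }
  }
  return false;
}

int main() {
  // Shape (4, 8, 16) where axis 1 was expanded via broadcasting (stride 0),
  // reducing over axis 2 only: axis 1 survives the reduction, so the input
  // must be copied to a contiguous array before the plan is recomputed.
  std::printf("%d\n", has_broadcasted_non_reduction_axis({128, 0, 1}, {2}));
  return 0;
}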

View File

@@ -104,13 +104,24 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
   size_t total = args.non_col_reductions * args.reduction_size;
   if (tile_x * BN + BN <= args.reduction_stride) {
-    for (size_t r = thread_y; r < total; r += BM) {
-      T vals[N_READS];
-      cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
-      for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
-      }
-      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
-    }
+    if (args.reduction_stride % N_READS == 0) {
+      for (size_t r = thread_y; r < total; r += BM) {
+        T vals[N_READS];
+        cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
+        for (int i = 0; i < N_READS; i++) {
+          totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        }
+        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+      }
+    } else {
+      for (size_t r = thread_y; r < total; r += BM) {
+        T vals[N_READS];
+        cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
+        for (int i = 0; i < N_READS; i++) {
+          totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        }
+        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+      }
+    }
   } else {
     for (size_t r = thread_y; r < total; r += BM) {
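
The hunk above only takes cub's vectorized block load when the reduction stride is divisible by N_READS; otherwise consecutive rows would start at unaligned offsets and the kernel falls back to element-wise loads. A minimal CUDA sketch of the same guarded-vectorization pattern, assuming one block per row launched with BLOCK == ceil(row_size / N_READS); the kernel name and launch shape are illustrative, not the MLX kernel.

#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

// Illustrative kernel: sums each row of a (rows x row_size) matrix,
// one block per row, assuming BLOCK == ceil(row_size / N_READS).
template <int BLOCK, int N_READS>
__global__ void sum_rows(float* in, float* out, int row_size) {
  float vals[N_READS];
  // Rows start at multiples of row_size, so when row_size % N_READS == 0
  // every row keeps the vector alignment of `in`.
  float* row = in + static_cast<size_t>(blockIdx.x) * row_size;
  if (row_size % N_READS == 0) {
    // Exact N_READS-wide chunks: the vectorized load stays in-bounds and aligned.
    cub::LoadDirectBlockedVectorized(static_cast<int>(threadIdx.x), row, vals);
  } else {
    // Ragged rows: scalar loads with bounds checking and a neutral fill value.
    cub::LoadDirectBlocked(
        static_cast<int>(threadIdx.x), row, vals, row_size, 0.0f);
  }
  float partial = 0.0f;
  for (int i = 0; i < N_READS; i++) {
    partial += vals[i];
  }
  using BlockReduce = cub::BlockReduce<float, BLOCK>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float total = BlockReduce(tmp).Sum(partial);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;
  }
}

For example, launched as sum_rows<256, 4><<<num_rows, 256>>>(in, out, 1024) for rows of 1024 floats, which takes the vectorized path.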
@@ -157,11 +168,13 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
 inline auto output_grid_for_col_reduce(
     const array& out,
     const cu::ColReduceArgs& args) {
-  auto out_shape = out.shape();
-  auto out_strides = out.strides();
-  while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
-    out_shape.pop_back();
-    out_strides.pop_back();
-  }
+  Shape out_shape;
+  Strides out_strides;
+  for (int i = 0; i < out.ndim(); i++) {
+    if (out.strides(i) >= args.reduction_stride) {
+      out_shape.push_back(out.shape(i));
+      out_strides.push_back(out.strides(i));
+    }
+  }
   return get_2d_grid_dims(out_shape, out_strides);
 }
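
The rewritten output_grid_for_col_reduce no longer assumes strides shrink toward the trailing axes. After a transposition the small strides can sit anywhere, so the new loop keeps every axis whose stride is at least reduction_stride instead of popping from the back. A self-contained sketch of the difference, using std::vector in place of the repo's Shape and Strides types:

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical host-side illustration: select the output axes that lie
// "above" the reduction stride. The old pop-from-the-back version assumed
// strides decrease toward the last axis (row-major); after a transpose the
// small strides can appear anywhere, so filtering per-axis is required.
int main() {
  // A transposed output: shape (8, 4), strides (1, 8); reduction_stride = 8.
  std::vector<int> shape = {8, 4};
  std::vector<int64_t> strides = {1, 8};
  int64_t reduction_stride = 8;

  std::vector<int> kept_shape;
  std::vector<int64_t> kept_strides;
  for (size_t i = 0; i < shape.size(); i++) {
    if (strides[i] >= reduction_stride) {
      kept_shape.push_back(shape[i]);
      kept_strides.push_back(strides[i]);
    }
  }
  // Keeps only axis 1 (size 4, stride 8). The while/pop_back version would
  // stop immediately (strides.back() == 8 is not < 8) and incorrectly keep
  // the stride-1 axis as well.
  for (size_t i = 0; i < kept_shape.size(); i++) {
    std::printf("axis kept: size=%d stride=%lld\n", kept_shape[i],
                static_cast<long long>(kept_strides[i]));
  }
  return 0;
}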

View File

@@ -113,7 +113,7 @@ inline void allocate_same_layout(
   auto out_strides = in.strides();
   for (auto ax : axes) {
     for (auto& s : out_strides) {
-      if (s > in.strides(ax)) {
+      if (s > in.strides(ax) && in.strides(ax) > 0) {
         s /= in.shape(ax);
       }
     }
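
The extra condition guards against broadcast reduction axes: with in.strides(ax) == 0, every positive stride compares greater and would be divided, corrupting the output layout. A small worked example of the same loop with the guard in place:

#include <cstdio>
#include <vector>

// Illustrative reconstruction of the stride-adjustment step: when dropping a
// reduction axis, any stride that "contains" that axis is divided by its
// size. A broadcast axis has stride 0, so without the in.strides(ax) > 0
// guard *every* stride would compare greater and be divided.
int main() {
  std::vector<long> shape = {4, 8};   // input shape
  std::vector<long> strides = {0, 1}; // axis 0 is broadcast (stride 0)
  int ax = 0;                         // reducing over the broadcast axis

  for (auto& s : strides) {
    if (s > strides[ax] && strides[ax] > 0) { // guard: skip zero strides
      s /= shape[ax];
    }
  }
  // With the guard, strides stay {0, 1}; without it, the stride-1 axis
  // would be divided by 4 and collapse to 0.
  std::printf("strides: {%ld, %ld}\n", strides[0], strides[1]);
  return 0;
}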

View File

@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
+#include <numeric>
+
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -57,20 +59,24 @@ struct RowReduceArgs {
   }
 
   // Convert shape and strides as if in was contiguous
-  void convert_shapes_to_contiguous(
-      const array& in,
-      const std::vector<int>& axes) {
+  void sort_access_pattern(const array& in, const std::vector<int>& axes) {
     auto shape_vec = in.shape();
     auto strides_vec = in.strides();
-    size_t s = 1;
-    for (int i = in.ndim() - 1; i >= 0; i--) {
-      strides_vec[i] = s;
-      s *= shape_vec[i];
-    }
     std::tie(shape_vec, strides_vec) =
         shapes_without_reduction_axes(shape_vec, strides_vec, axes);
+    std::vector<int> indices(shape_vec.size());
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+      return strides_vec[left] > strides_vec[right];
+    });
+    decltype(shape_vec) sorted_shape;
+    decltype(strides_vec) sorted_strides;
+    for (auto idx : indices) {
+      sorted_shape.push_back(shape_vec[idx]);
+      sorted_strides.push_back(strides_vec[idx]);
+    }
     std::tie(shape_vec, strides_vec) =
-        collapse_contiguous_dims(shape_vec, strides_vec);
+        collapse_contiguous_dims(sorted_shape, sorted_strides);
     shape = const_param(shape_vec);
     strides = const_param(strides_vec);
     ndim = shape_vec.size();
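
The renamed sort_access_pattern no longer pretends the input is contiguous; it argsorts the surviving axes by decreasing stride (hence the new <numeric> include for std::iota) so that a transposed input is traversed in memory order before the dims are collapsed. A self-contained illustration of the argsort step, with the collapse itself elided:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Sketch of the argsort-by-stride idea: ordering the surviving axes by
// decreasing stride turns a transposed view back into a row-major access
// pattern, which then collapses into fewer (ideally one) dimensions.
int main() {
  // A (4, 6) array viewed through a transpose: strides (1, 4).
  std::vector<int> shape = {4, 6};
  std::vector<long> strides = {1, 4};

  std::vector<int> indices(shape.size());
  std::iota(indices.begin(), indices.end(), 0);
  std::sort(indices.begin(), indices.end(), [&](int l, int r) {
    return strides[l] > strides[r];
  });

  for (int idx : indices) {
    // Prints (size 6, stride 4) then (size 4, stride 1): row-major order,
    // which a collapse pass can merge into a single axis of 24 elements.
    std::printf("size=%d stride=%ld\n", shape[idx], strides[idx]);
  }
  return 0;
}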
@@ -282,7 +288,7 @@ void row_reduce_looped(
   using U = cu::ReduceResult<OP, T>::type;
 
   // Calculate the grid and block dims
-  args.convert_shapes_to_contiguous(x, axes);
+  args.sort_access_pattern(x, axes);
   dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
   size_t reductions = (args.row_size + N_READS - 1) / N_READS;
   int threads = std::min(1024UL, reductions);
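
For the launch configuration, each thread consumes N_READS elements of a row, so the block size is the ceiling division of the row length by N_READS, capped at CUDA's 1024-thread block limit. A worked example of the arithmetic:

#include <algorithm>
#include <cstdio>

// Worked example of the launch arithmetic above: each thread consumes
// N_READS elements, so the block needs ceil(row_size / N_READS) threads,
// capped at the 1024-thread CUDA block limit.
int main() {
  constexpr size_t N_READS = 4;
  size_t row_size = 3000;
  size_t reductions = (row_size + N_READS - 1) / N_READS; // 750
  size_t threads = std::min(size_t(1024), reductions);    // 750
  std::printf("reductions=%zu threads=%zu\n", reductions, threads);
  return 0;
}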

View File

@@ -1,7 +1,6 @@
 cuda_skip = {
     "TestArray.test_api",
     "TestBF16.test_arg_reduction_ops",
-    "TestBF16.test_reduction_ops",
     "TestBlas.test_complex_gemm",
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
@@ -15,8 +14,6 @@ cuda_skip = {
     "TestOps.test_dynamic_slicing",
-    "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
-    "TestReduce.test_expand_sums",
     "TestReduce.test_many_reduction_axes",
     "TestUpsample.test_torch_upsample",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",