diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 8936bbf71..4e1195c98 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -34,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // If it is a general reduce then copy the input to a contiguous array and
   // recompute the plan.
-  if (plan.type == GeneralReduce) {
+  //
+  // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
+  //       like we do in Metal. When it comes to broadcasted reduction axes
+  //       some can be ignored eg for min/max.
+  bool broadcasted = false;
+  for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
+    if (j < axes_.size() && axes_[j] == i) {
+      j++;
+    } else {
+      broadcasted = in.strides(i) == 0;
+    }
+  }
+  if (plan.type == GeneralReduce || broadcasted) {
     array in_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, in_copy, CopyType::General, s);
     encoder.add_temporary(in_copy);
diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index bbfed594d..094e667c5 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -104,13 +104,24 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
   size_t total = args.non_col_reductions * args.reduction_size;
   if (tile_x * BN + BN <= args.reduction_stride) {
-    for (size_t r = thread_y; r < total; r += BM) {
-      T vals[N_READS];
-      cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
-      for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast(vals[i]));
+    if (args.reduction_stride % N_READS == 0) {
+      for (size_t r = thread_y; r < total; r += BM) {
+        T vals[N_READS];
+        cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
+        for (int i = 0; i < N_READS; i++) {
+          totals[i] = op(totals[i], __cast(vals[i]));
+        }
+        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+      }
+    } else {
+      for (size_t r = thread_y; r < total; r += BM) {
+        T vals[N_READS];
+        cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
+        for (int i = 0; i < N_READS; i++) {
+          totals[i] = op(totals[i], __cast(vals[i]));
+        }
+        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
       }
-      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
   } else {
     for (size_t r = thread_y; r < total; r += BM) {
@@ -157,11 +168,13 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
 inline auto output_grid_for_col_reduce(
     const array& out,
     const cu::ColReduceArgs& args) {
-  auto out_shape = out.shape();
-  auto out_strides = out.strides();
-  while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
-    out_shape.pop_back();
-    out_strides.pop_back();
+  Shape out_shape;
+  Strides out_strides;
+  for (int i = 0; i < out.ndim(); i++) {
+    if (out.strides(i) >= args.reduction_stride) {
+      out_shape.push_back(out.shape(i));
+      out_strides.push_back(out.strides(i));
+    }
   }
   return get_2d_grid_dims(out_shape, out_strides);
 }
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index 5a3d09bf5..057f8286c 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -113,7 +113,7 @@ inline void allocate_same_layout(
   auto out_strides = in.strides();
   for (auto ax : axes) {
     for (auto& s : out_strides) {
-      if (s > in.strides(ax)) {
+      if (s > in.strides(ax) && in.strides(ax) > 0) {
         s /= in.shape(ax);
       }
     }
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 4c005cc52..368b6a23d 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
+#include <numeric>
+
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -57,20 +59,24 @@ struct RowReduceArgs {
   }
 
   // Convert shape and strides as if in was contiguous
-  void convert_shapes_to_contiguous(
-      const array& in,
-      const std::vector<int>& axes) {
+  void sort_access_pattern(const array& in, const std::vector<int>& axes) {
     auto shape_vec = in.shape();
     auto strides_vec = in.strides();
-    size_t s = 1;
-    for (int i = in.ndim() - 1; i >= 0; i--) {
-      strides_vec[i] = s;
-      s *= shape_vec[i];
-    }
     std::tie(shape_vec, strides_vec) =
         shapes_without_reduction_axes(shape_vec, strides_vec, axes);
+    std::vector<int> indices(shape_vec.size());
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+      return strides_vec[left] > strides_vec[right];
+    });
+    decltype(shape_vec) sorted_shape;
+    decltype(strides_vec) sorted_strides;
+    for (auto idx : indices) {
+      sorted_shape.push_back(shape_vec[idx]);
+      sorted_strides.push_back(strides_vec[idx]);
+    }
     std::tie(shape_vec, strides_vec) =
-        collapse_contiguous_dims(shape_vec, strides_vec);
+        collapse_contiguous_dims(sorted_shape, sorted_strides);
     shape = const_param(shape_vec);
     strides = const_param(strides_vec);
     ndim = shape_vec.size();
@@ -282,7 +288,7 @@ void row_reduce_looped(
   using U = cu::ReduceResult::type;
 
   // Calculate the grid and block dims
-  args.convert_shapes_to_contiguous(x, axes);
+  args.sort_access_pattern(x, axes);
   dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
   size_t reductions = (args.row_size + N_READS - 1) / N_READS;
   int threads = std::min(1024UL, reductions);
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 89805d017..183521ea7 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,7 +1,6 @@
 cuda_skip = {
     "TestArray.test_api",
     "TestBF16.test_arg_reduction_ops",
-    "TestBF16.test_reduction_ops",
     "TestBlas.test_complex_gemm",
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
@@ -15,8 +14,6 @@ cuda_skip = {
     "TestOps.test_dynamic_slicing",
     "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
-    "TestReduce.test_expand_sums",
-    "TestReduce.test_many_reduction_axes",
     "TestUpsample.test_torch_upsample",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",