diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 4e1195c98..8350eebb7 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -46,7 +46,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       broadcasted = in.strides(i) == 0;
     }
   }
-  if (plan.type == GeneralReduce || broadcasted) {
+  if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
     array in_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, in_copy, CopyType::General, s);
     encoder.add_temporary(in_copy);
diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index 54dd351bb..0467da104 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -33,18 +33,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   size_t end = start + block_step;
   size_t check = min(end, size);

-  for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
+  size_t i = start;
+  for (; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals);
     for (int j = 0; j < N; j++) {
       accs[0] = op(accs[0], __cast<U, T>(vals[j]));
     }
   }

-  if (end > size) {
-    size_t offset = end - block.size() * N;
-    int block_end = size - offset;
+  if (i < check) {
     cub::LoadDirectBlocked(
-        block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
+        block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
     for (int i = 0; i < N; i++) {
       accs[0] = op(accs[0], __cast<U, T>(vals[i]));
     }
@@ -70,24 +69,27 @@ void all_reduce(
   out.set_data(allocator::malloc(out.nbytes()));

   auto get_args = [](size_t size, int N) {
-    size_t reductions = size / N;
-    int threads = 512;
-    size_t full_blocks = (reductions + threads - 1) / threads;
+    int threads = std::min(512UL, (size + N - 1) / N);
+    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+    int reductions_per_step = threads * N;
+    size_t steps_needed =
+        (size + reductions_per_step - 1) / reductions_per_step;
+
     int blocks;
-    if (full_blocks < 32) {
+    if (steps_needed < 32) {
       blocks = 1;
-    } else if (full_blocks < 128) {
+    } else if (steps_needed < 128) {
       blocks = 32;
-    } else if (full_blocks < 512) {
+    } else if (steps_needed < 512) {
       blocks = 128;
-    } else if (full_blocks < 1024) {
+    } else if (steps_needed < 1024) {
       blocks = 512;
     } else {
       blocks = 1024;
     }
-    size_t reductions_per_block = std::max(
-        static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
-    size_t block_step = reductions_per_block * N;
+
+    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
+    size_t block_step = steps_per_block * reductions_per_step;
     return std::make_tuple(blocks, threads, block_step);
   };

@@ -99,7 +101,6 @@ void all_reduce(
   // Large array so allocate an intermediate and accumulate there
   std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
   if (blocks > 1) {
-    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
     array intermediate({blocks}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc(intermediate.nbytes()));
     encoder.add_temporary(intermediate);
diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index 094e667c5..8dbebe386 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.

+#include <numeric>
+
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -47,8 +49,19 @@ struct ColReduceArgs {
       shape_vec.pop_back();
       strides_vec.pop_back();
     }
+    std::vector<int> indices(shape_vec.size());
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+      return strides_vec[left] > strides_vec[right];
+    });
+    decltype(shape_vec) sorted_shape;
+    decltype(strides_vec) sorted_strides;
+    for (auto idx : indices) {
+      sorted_shape.push_back(shape_vec[idx]);
+      sorted_strides.push_back(strides_vec[idx]);
+    }
     std::tie(shape_vec, strides_vec) =
-        collapse_contiguous_dims(shape_vec, strides_vec);
+        collapse_contiguous_dims(sorted_shape, sorted_strides);
     shape = const_param(shape_vec);
     strides = const_param(strides_vec);
     ndim = shape_vec.size();
@@ -167,16 +180,18 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {

 inline auto output_grid_for_col_reduce(
     const array& out,
-    const cu::ColReduceArgs& args) {
-  Shape out_shape;
-  Strides out_strides;
-  for (int i = 0; i < out.ndim(); i++) {
-    if (out.strides(i) >= args.reduction_stride) {
-      out_shape.push_back(out.shape(i));
-      out_strides.push_back(out.strides(i));
-    }
+    const cu::ColReduceArgs& args,
+    int bn) {
+  int gx, gy = 1;
+  size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
+  size_t n_outer_blocks = out.size() / args.reduction_stride;
+  size_t n_blocks = n_outer_blocks * n_inner_blocks;
+  while (n_blocks / gy > INT32_MAX) {
+    gy *= 2;
   }
-  return get_2d_grid_dims(out_shape, out_strides);
+  gx = cuda::ceil_div(n_blocks, gy);
+
+  return dim3(gx, gy, 1);
 }

 void col_reduce_looped(
@@ -207,16 +222,7 @@ void col_reduce_looped(
       constexpr int N_READS = 4;
       constexpr int BM = 32;
       constexpr int BN = 32;
-      dim3 grid = output_grid_for_col_reduce(out, args);
-      size_t extra_blocks = cuda::ceil_div(args.reduction_stride, BN);
-      if (grid.x * extra_blocks < INT32_MAX) {
-        grid.x *= extra_blocks;
-      } else if (grid.y * extra_blocks < 65536) {
-        grid.y *= extra_blocks;
-      } else {
-        throw std::runtime_error(
-            "[col_reduce_looped] Need to factorize reduction_stride");
-      }
+      dim3 grid = output_grid_for_col_reduce(out, args, BN);
       int blocks = BM * BN / N_READS;
       auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
       kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index 057f8286c..0cbccf69e 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -2,6 +2,8 @@

 #pragma once

+#include <numeric>
+
 #include "mlx/backend/cuda/device/utils.cuh"

 #include <cooperative_groups.h>
@@ -106,19 +108,31 @@ inline void allocate_same_layout(
     array& out,
     const array& in,
     const std::vector<int>& axes) {
-  // Initialize out such that it matches in's layout. Basically we keep any
-  // transpositions as it were and that allows us either to skip finding the
-  // location of the output that matches the input or simply contiguous read or
-  // writes.
-  auto out_strides = in.strides();
-  for (auto ax : axes) {
-    for (auto& s : out_strides) {
-      if (s > in.strides(ax) && in.strides(ax) > 0) {
-        s /= in.shape(ax);
-      }
-    }
+  // Calculate the transpositions applied to in in order to apply them to out.
+  std::vector<int> axis_order(in.ndim());
+  std::iota(axis_order.begin(), axis_order.end(), 0);
+  std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
+    return in.strides(left) > in.strides(right);
+  });
+
+  // Transpose the shape and calculate the strides
+  Shape out_shape(in.ndim());
+  Strides out_strides(in.ndim(), 1);
+  for (int i = 0; i < in.ndim(); i++) {
+    out_shape[i] = out.shape(axis_order[i]);
   }
-  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
+  for (int i = in.ndim() - 2; i >= 0; i--) {
+    out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
+  }
+
+  // Reverse the axis order to get the final strides
+  Strides final_strides(in.ndim());
+  for (int i = 0; i < in.ndim(); i++) {
+    final_strides[axis_order[i]] = out_strides[i];
+  }
+
+  // Calculate the resulting contiguity and do the memory allocation
+  auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
   auto fl = in.flags();
   fl.row_contiguous = rc;
   fl.col_contiguous = cc;
@@ -126,7 +140,7 @@ inline void allocate_same_layout(
   out.set_data(
       allocator::malloc(out.nbytes()),
       data_size,
-      out_strides,
+      final_strides,
       fl,
       allocator::free);
 }
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 368b6a23d..7e155795a 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -105,12 +105,28 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   in += start_row * size;
   out += start_row;

-  for (size_t r = 0; r < full_blocks; r++) {
-    for (int k = 0; k < M; k++) {
-      cub::LoadDirectBlockedVectorized(
-          block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
-      for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+  if (size % N == 0) {
+    for (size_t r = 0; r < full_blocks; r++) {
+      for (int k = 0; k < M; k++) {
+        cub::LoadDirectBlockedVectorized(
+            block.thread_rank(),
+            in + k * size + r * (block.size() * N),
+            vals[k]);
+        for (int j = 0; j < N; j++) {
+          accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        }
+      }
+    }
+  } else {
+    for (size_t r = 0; r < full_blocks; r++) {
+      for (int k = 0; k < M; k++) {
+        cub::LoadDirectBlocked(
+            block.thread_rank(),
+            in + k * size + r * (block.size() * N),
+            vals[k]);
+        for (int j = 0; j < N; j++) {
+          accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        }
       }
     }
   }
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 183521ea7..cba642ca1 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -12,7 +12,6 @@ cuda_skip = {
     "TestLayers.test_upsample",
     "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
-    "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
     "TestUpsample.test_torch_upsample",
     # Block masked matmul NYI
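For reference, the stride computation in the new allocate_same_layout above can be exercised on its own: sort the axes by the input's strides, lay the output out contiguously in that permuted order, then scatter the strides back to the original axis positions. Below is a minimal host-side sketch of that logic, using plain std::vector in place of MLX's Shape and Strides types; the helper name and the example values are illustrative and not part of the patch.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Compute output strides that follow the same axis ordering (transposition)
// as the input, mirroring the stride logic of allocate_same_layout above.
std::vector<int64_t> strides_matching_layout(
    const std::vector<int64_t>& in_strides,
    const std::vector<int>& out_shape) {
  int ndim = static_cast<int>(out_shape.size());

  // Axis order of the input, from largest to smallest stride.
  std::vector<int> axis_order(ndim);
  std::iota(axis_order.begin(), axis_order.end(), 0);
  std::sort(axis_order.begin(), axis_order.end(), [&](int l, int r) {
    return in_strides[l] > in_strides[r];
  });

  // Contiguous strides for the output shape permuted into that axis order.
  std::vector<int64_t> permuted_strides(ndim, 1);
  for (int i = ndim - 2; i >= 0; i--) {
    permuted_strides[i] =
        permuted_strides[i + 1] * out_shape[axis_order[i + 1]];
  }

  // Scatter the strides back to the original axis numbering.
  std::vector<int64_t> final_strides(ndim);
  for (int i = 0; i < ndim; i++) {
    final_strides[axis_order[i]] = permuted_strides[i];
  }
  return final_strides;
}

int main() {
  // A 4x3x2 input stored with axis 1 outermost (strides {2, 8, 1}),
  // reduced over axis 2 with keepdims, so the output shape is {4, 3, 1}.
  for (auto s : strides_matching_layout({2, 8, 1}, {4, 3, 1})) {
    std::cout << s << ' '; // prints: 1 4 1
  }
  std::cout << '\n';
  return 0;
}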