diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 553793944..8922000e2 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -21,12 +21,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(!axes_.empty());
   assert(out.size() != in.size());
 
-  // out.set_data(allocator::malloc(out.nbytes()));
-
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  // encoder.set_input_array(in);
-  // encoder.set_output_array(out);
 
   if (in.size() == 0) {
     throw std::runtime_error("Should never reach here.");
@@ -50,11 +46,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
-    segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
-    return;
-  }
-
   if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
     row_reduce(encoder, in, out, reduce_type_, axes_, plan);
     return;
diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index fd5ec256f..557c973f3 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -13,26 +13,6 @@ namespace cu {
 
 namespace cg = cooperative_groups;
 
-namespace {
-
-// TODO: Should make a custom complex type
-template <typename U, typename T>
-inline __device__ U __cast(T x) {
-  return static_cast<U>(x);
-}
-
-template <>
-inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
-  return x.x != 0 && x.y != 0;
-}
-
-template <>
-inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
-  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
-}
-
-} // namespace
-
 template <typename T, typename U, typename ReduceOp, int N = 4>
 __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   auto grid = cg::this_grid();
@@ -54,8 +34,8 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
 
   for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals);
-    for (int i = 0; i < N; i++) {
-      accs[i] = op(accs[i], __cast<U>(vals[i]));
+    for (int j = 0; j < N; j++) {
+      accs[j] = op(accs[j], __cast<U>(vals[j]));
     }
   }
 
@@ -74,15 +54,17 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   }
   accs[0] = cg::reduce(warp, accs[0], op);
-  __shared__ U shared_accumulators[32];
-  if (warp.thread_rank() == 0) {
-    shared_accumulators[warp.meta_group_rank()] = accs[0];
+  if (warp.meta_group_size() > 1) {
+    __shared__ U shared_accumulators[32];
+    if (warp.thread_rank() == 0) {
+      shared_accumulators[warp.meta_group_rank()] = accs[0];
+    }
+    block.sync();
+    accs[0] = (warp.thread_rank() < warp.meta_group_size())
+        ? shared_accumulators[warp.thread_rank()]
+        : init;
+    accs[0] = cg::reduce(warp, accs[0], op);
   }
-  block.sync();
-  accs[0] = (warp.thread_rank() < warp.meta_group_size())
-      ? shared_accumulators[warp.thread_rank()]
-      : init;
-  accs[0] = cg::reduce(warp, accs[0], op);
 
   if (block.thread_rank() == 0) {
     out[grid.block_rank()] = accs[0];
   }
@@ -96,7 +78,7 @@ void all_reduce(
     const array& in,
     array& out,
     Reduce::ReduceType reduce_type) {
-  constexpr int N_READS = 4;
+  constexpr int N_READS = 8;
 
   out.set_data(allocator::malloc(out.nbytes()));
 
@@ -118,14 +100,13 @@ void all_reduce(
     }
     size_t reductions_per_block = std::max(
         static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
-    size_t block_step = reductions_per_block * N_READS;
+    size_t block_step = reductions_per_block * N;
     return std::make_tuple(blocks, threads, block_step);
   };
 
   int blocks, threads;
   size_t block_step;
 
-  bool large = in.size() > N_READS * 1024;
   array x = in;
 
   // Large array so allocate an intermediate and accumulate there
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index c19a801de..b40d2bd4e 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -3,49 +3,10 @@
 #pragma once
 
 #include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
 
 namespace mlx::core::cu {
 
-template <size_t N>
-struct uint_by_size;
-template <>
-struct uint_by_size<2> {
-  using type = uint16_t;
-};
-template <>
-struct uint_by_size<4> {
-  using type = uint32_t;
-};
-template <>
-struct uint_by_size<8> {
-  using type = unsigned long long int;
-};
-
-template <typename Op, typename T>
-__device__ void atomic_reduce(T* x, T y) {
-  if constexpr (sizeof(T) == 1) {
-    using U = uint16_t;
-    U* x_int = (U*)((char*)x - ((size_t)x % 2));
-    int shift = ((char*)x - (char*)x_int) * 8;
-    int mask = 0xff << shift;
-    U old_val, new_val;
-    do {
-      old_val = *x_int;
-      T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
-      new_val = (old_val & ~mask) | (result << shift);
-    } while (atomicCAS(x_int, old_val, new_val) != old_val);
-  } else {
-    using U = typename uint_by_size<sizeof(T)>::type;
-    U* x_int = (U*)(x);
-    U old_val, new_val;
-    do {
-      old_val = *x_int;
-      T result = Op{}(*((T*)&old_val), y);
-      new_val = *((U*)&result);
-    } while (atomicCAS(x_int, old_val, new_val) != old_val);
-  }
-}
-
 // Reduce ops.
 struct And {
   __device__ __forceinline__ bool operator()(bool a, bool b) {
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
new file mode 100644
index 000000000..2d83bd366
--- /dev/null
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -0,0 +1,65 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/device/utils.cuh"
+
+namespace mlx::core::cu {
+
+template <size_t N>
+struct uint_by_size;
+template <>
+struct uint_by_size<2> {
+  using type = uint16_t;
+};
+template <>
+struct uint_by_size<4> {
+  using type = uint32_t;
+};
+template <>
+struct uint_by_size<8> {
+  using type = unsigned long long int;
+};
+
+template <typename Op, typename T>
+__device__ void atomic_reduce(T* x, T y) {
+  if constexpr (sizeof(T) == 1) {
+    using U = uint16_t;
+    U* x_int = (U*)((char*)x - ((size_t)x % 2));
+    int shift = ((char*)x - (char*)x_int) * 8;
+    int mask = 0xff << shift;
+    U old_val, new_val;
+    do {
+      old_val = *x_int;
+      T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
+      new_val = (old_val & ~mask) | (result << shift);
+    } while (atomicCAS(x_int, old_val, new_val) != old_val);
+  } else {
+    using U = typename uint_by_size<sizeof(T)>::type;
+    U* x_int = (U*)(x);
+    U old_val, new_val;
+    do {
+      old_val = *x_int;
+      T result = Op{}(*((T*)&old_val), y);
+      new_val = *((U*)&result);
+    } while (atomicCAS(x_int, old_val, new_val) != old_val);
+  }
+}
+
+// TODO: Should make a custom complex type
+template <typename U, typename T>
+inline __device__ U __cast(T x) {
+  return static_cast<U>(x);
+}
+
+template <>
+inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
+  return x.x != 0 && x.y != 0;
+}
+
+template <>
+inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
+  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index ae54a27d6..213456692 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -188,8 +188,152 @@ __global__ void row_reduce_looped(
   }
 }
 
+template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
+__global__ void
+row_reduce_per_threadblock(T* in, U* out, size_t n_rows, int size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  const U init = cu::ReduceInit<ReduceOp, T>::value();
+  ReduceOp op;
+
+  T vals[M][N];
+  U accs[M];
+  for (int i = 0; i < M; i++) {
+    accs[i] = init;
+  }
+
+  const size_t start_row =
+      min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
+  in += start_row * size;
+  out += start_row;
+
+  int i = 0;
+  for (; i + block.size() * N <= size; i += block.size() * N) {
+    for (int k = 0; k < M; k++) {
+      cub::LoadDirectBlockedVectorized(
+          block.thread_rank(), in + k * size + i, vals[k]);
+      for (int j = 0; j < N; j++) {
+        accs[k] = op(accs[k], __cast<U>(vals[k][j]));
+      }
+    }
+  }
+
+  if (size > i) {
+    for (int k = 0; k < M; k++) {
+      cub::LoadDirectBlocked(
+          block.thread_rank(),
+          in + k * size + i,
+          vals[k],
+          size - i,
+          __cast<T>(init));
+      for (int j = 0; j < N; j++) {
+        accs[k] = op(accs[k], __cast<U>(vals[k][j]));
+      }
+    }
+  }
+
+  for (int i = 0; i < M; i++) {
+    accs[i] = cg::reduce(warp, accs[i], op);
+  }
+
+  if (warp.meta_group_size() > 1) {
+    __shared__ U shared_accumulators[32 * M];
+    if (warp.thread_rank() == 0) {
+      for (int i = 0; i < M; i++) {
+        shared_accumulators[warp.meta_group_rank() * M + i] = accs[i];
+      }
+    }
+    block.sync();
+    if (warp.thread_rank() < warp.meta_group_size()) {
+      for (int i = 0; i < M; i++) {
+        accs[i] = shared_accumulators[warp.thread_rank() * M + i];
+      }
+    } else {
+      for (int i = 0; i < M; i++) {
+        accs[i] = init;
+      }
+    }
+    for (int i = 0; i < M; i++) {
+      accs[i] = cg::reduce(warp, accs[i], op);
+    }
+  }
+
+  if (block.thread_rank() == 0) {
+    if (grid.block_rank() * M + M <= n_rows) {
+      for (int i = 0; i < M; i++) {
+        out[i] = accs[i];
+      }
+    } else {
+      short offset = grid.block_rank() * M + M - n_rows;
+      for (int i = offset; i < M; i++) {
+        out[i] = accs[i];
+      }
+    }
+  }
+}
+
 } // namespace cu
 
+void row_reduce_simple(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan) {
+  constexpr int N_READS = 8;
+
+  // Initialize out such that its strides match in's layout (except the fastest
+  // moving axis)
+  auto [_, out_strides] = shapes_without_reduction_axes(in, axes);
+  for (auto& s : out_strides) {
+    s /= plan.shape.back();
+  }
+  auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
+  auto fl = in.flags();
+  fl.row_contiguous = rc;
+  fl.col_contiguous = cc;
+  fl.contiguous = data_size == out.size();
+  out.set_data(
+      allocator::malloc(out.nbytes()),
+      data_size,
+      out_strides,
+      fl,
+      allocator::free);
+
+  // Just a way to get out of the constness because cub doesn't like it ...
+  // (sigh)
+  array x = in;
+
+  // TODO: If out.size() < 1024 which will be a common case then write this in
+  // 2 passes. Something like 32 * out.size() and then do a warp reduce.
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+
+        // Calculate the grid and block dims
+        size_t reductions = plan.shape.back() / N_READS;
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        int threads = std::min(1024UL, reductions);
+        dim3 block(threads, 1, 1);
+        auto kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS>;
+        if (grid.x >= 1024) {
+          grid.x = (grid.x + 1) / 2;
+          kernel = cu::row_reduce_per_threadblock<T, U, OP, N_READS, 2>;
+        }
+        kernel<<<grid, block, 0, stream>>>(
+            x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
+      });
+    });
+  });
+}
+
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -197,54 +341,58 @@ void row_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
-  cu::RowReduceArgs args(in, plan, axes);
+  if (plan.shape.size() == 1) {
+    row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
+  }
+
+  // cu::RowReduceArgs args(in, plan, axes);
 
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr size_t N_READS = 4;
-          dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
-          dim3 block_dims, num_blocks;
-          auto kernel =
-              cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          if (args.row_size <= 64) {
-            if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
-                (args.non_row_reductions <= 8)) {
-              block_dims.x = std::min(out_dims.x, 1024u);
-              num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
-              num_blocks.y = out_dims.y;
-            } else {
-              block_dims.x = WARP_SIZE;
-              num_blocks.y = out_dims.x;
-              num_blocks.z = out_dims.y;
-              kernel =
-                  cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
-            }
-          } else {
-            size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
-            num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
-            MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
-              num_blocks.y = out_dims.x;
-              num_blocks.z = out_dims.y;
-              block_dims.x = BLOCK_DIM_X;
-              kernel = cu::row_reduce_looped<
-                  InType,
-                  OutType,
-                  OP,
-                  NDIM,
-                  BLOCK_DIM_X,
-                  N_READS>;
-            });
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), out.size(), args);
-        });
-      });
-    });
-  });
+  // encoder.launch_kernel([&](cudaStream_t stream) {
+  //   MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+  //     using InType = cuda_type_t<CTYPE>;
+  //     MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+  //       using OutType = cu::ReduceResult<OP, InType>::type;
+  //       MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+  //         constexpr size_t N_READS = 4;
+  //         dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
+  //         dim3 block_dims, num_blocks;
+  //         auto kernel =
+  //             cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
+  //         if (args.row_size <= 64) {
+  //           if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
+  //               (args.non_row_reductions <= 8)) {
+  //             block_dims.x = std::min(out_dims.x, 1024u);
+  //             num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
+  //             num_blocks.y = out_dims.y;
+  //           } else {
+  //             block_dims.x = WARP_SIZE;
+  //             num_blocks.y = out_dims.x;
+  //             num_blocks.z = out_dims.y;
+  //             kernel =
+  //                 cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
+  //           }
+  //         } else {
+  //           size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
+  //           num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
+  //           MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
+  //             num_blocks.y = out_dims.x;
+  //             num_blocks.z = out_dims.y;
+  //             block_dims.x = BLOCK_DIM_X;
+  //             kernel = cu::row_reduce_looped<
+  //                 InType,
+  //                 OutType,
+  //                 OP,
+  //                 NDIM,
+  //                 BLOCK_DIM_X,
+  //                 N_READS>;
+  //           });
+  //         }
+  //         kernel<<<num_blocks, block_dims, 0, stream>>>(
+  //             in.data<InType>(), out.data<OutType>(), out.size(), args);
+  //       });
+  //     });
+  //   });
+  // });
 }
 
 } // namespace mlx::core
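For reference, below is a minimal standalone sketch of the two-level reduction pattern (per-thread accumulation, warp-level cg::reduce, then a cross-warp pass staged through shared memory) that both all_reduce and row_reduce_per_threadblock in the patch rely on. It is not part of the patch: the kernel name block_sum, the plain float sum, and the hard-coded warp size of 32 are assumptions for illustration only.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

__global__ void block_sum(const float* in, float* out, int n) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  // Each thread accumulates a strided slice of the input.
  float acc = 0.0f;
  for (int i = block.thread_rank(); i < n; i += block.size()) {
    acc += in[i];
  }

  // Reduce within each warp.
  acc = cg::reduce(warp, acc, cg::plus<float>());

  // Stage one partial per warp in shared memory and let the first warp
  // combine them; skipped when the block is a single warp (same guard the
  // patch adds to all_reduce).
  if (warp.meta_group_size() > 1) {
    __shared__ float partials[32];
    if (warp.thread_rank() == 0) {
      partials[warp.meta_group_rank()] = acc;
    }
    block.sync();
    acc = (warp.thread_rank() < warp.meta_group_size())
        ? partials[warp.thread_rank()]
        : 0.0f;
    acc = cg::reduce(warp, acc, cg::plus<float>());
  }

  if (block.thread_rank() == 0) {
    out[blockIdx.x] = acc;
  }
}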