From 9cf7ef1068c818cc82140eb4f24fe6275b1ccbef Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 17 Jun 2025 23:58:51 -0700
Subject: [PATCH] Add all reduce and atomic updates

---
 mlx/backend/cuda/CMakeLists.txt        |   1 +
 mlx/backend/cuda/binary_two.cu         |   2 +-
 mlx/backend/cuda/reduce.cu             |  30 ++---
 mlx/backend/cuda/reduce/all_reduce.cu  | 160 +++++++++++++++++++++++++
 mlx/backend/cuda/reduce/reduce.cuh     |   6 +
 mlx/backend/cuda/reduce/reduce_ops.cuh |  82 ++++++++++++-
 6 files changed, 259 insertions(+), 22 deletions(-)
 create mode 100644 mlx/backend/cuda/reduce/all_reduce.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index ad979a13f..6487d6aab 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -29,6 +29,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 3047e39f0..074c947da 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
       if (ndim <= 3) {
         MLX_SWITCH_1_2_3(ndim, NDIM, {
           auto kernel =
-              &cu::binary_g_nd<Op, InType, OutType, NDIM>;
+              cu::binary_g_nd<Op, InType, OutType, NDIM>;
           auto [num_blocks, block_dims] =
               get_launch_args(kernel, out_a, large);
           kernel<<<num_blocks, block_dims, 0, stream>>>(
diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index a740113db..553793944 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -21,29 +21,15 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(!axes_.empty());
   assert(out.size() != in.size());
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  // out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
+  // encoder.set_input_array(in);
+  // encoder.set_output_array(out);
 
-  // Fill out with init value.
   if (in.size() == 0) {
-    encoder.launch_kernel([&](cudaStream_t stream) {
-      MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-        MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
-          using InType = cuda_type_t<CTYPE>;
-          using OutType = cu::ReduceResult<OP, InType>::type;
-          thrust::fill_n(
-              cu::thrust_policy(stream),
-              thrust::device_pointer_cast(out.data<OutType>()),
-              out.data_size(),
-              cu::ReduceInit<OP, InType>::value());
-        });
-      });
-    });
-    return;
+    throw std::runtime_error("Should never reach here.");
   }
 
   // Reduce.
@@ -59,8 +45,12 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     plan = get_reduction_plan(in, axes_);
   }
 
-  if ((plan.type == ContiguousAllReduce) ||
-      (plan.type == ContiguousReduce && plan.shape.size() == 1)) {
+  if (plan.type == ContiguousAllReduce) {
+    all_reduce(encoder, in, out, reduce_type_);
+    return;
+  }
+
+  if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
     segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
     return;
   }
diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
new file mode 100644
index 000000000..64a793f14
--- /dev/null
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -0,0 +1,160 @@
+// Copyright © 2025 Apple Inc.
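+//
+// Each thread block reduces a contiguous slice of `block_step` input elements
+// to a single partial value: vectorized loads of N elements per thread, a
+// warp-level cg::reduce, then a shared-memory combine across the warps.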
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/block/block_load.cuh>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+namespace {
+
+// TODO: Should make a custom complex type
+template <typename U, typename T>
+inline __device__ U __cast(T x) {
+  return static_cast<U>(x);
+}
+
+template <>
+inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
+  return x.x != 0 && x.y != 0;
+}
+
+template <>
+inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
+  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+}
+
+} // namespace
+
+template <typename T, typename U, typename ReduceOp, int N = 4>
+__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  const U init = cu::ReduceInit<ReduceOp, T>::value();
+  ReduceOp op;
+
+  T vals[N];
+  U accs[N];
+  for (int i = 0; i < N; i++) {
+    accs[i] = init;
+  }
+
+  size_t start = grid.block_rank() * block_step;
+  size_t end = start + block_step;
+  size_t check = min(end, size);
+
+  for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
+    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
+    for (int i = 0; i < N; i++) {
+      accs[i] = op(accs[i], __cast<U>(vals[i]));
+    }
+  }
+
+  if (end > size) {
+    size_t offset = end - block.size() * N;
+    int block_end = size - offset;
+    cub::LoadDirectBlocked(
+        block.thread_rank(), in + offset, vals, block_end, __cast<T>(init));
+    for (int i = 0; i < N; i++) {
+      accs[i] = op(accs[i], __cast<U>(vals[i]));
+    }
+  }
+
+  for (int i = 1; i < N; i++) {
+    accs[0] = op(accs[0], accs[i]);
+  }
+  accs[0] = cg::reduce(warp, accs[0], op);
+
+  __shared__ U shared_accumulators[32];
+  if (warp.thread_rank() == 0) {
+    shared_accumulators[warp.meta_group_rank()] = accs[0];
+  }
+  block.sync();
+  accs[0] = (warp.thread_rank() < warp.meta_group_size())
+      ? shared_accumulators[warp.thread_rank()]
+      : init;
+  accs[0] = cg::reduce(warp, accs[0], op);
+
+  if (block.thread_rank() == 0) {
+    out[grid.block_rank()] = accs[0];
+  }
+}
+
+} // namespace cu
+
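+// Host-side entry point: get_args picks the launch geometry so that every
+// block consumes a contiguous slice of `block_step` elements. Inputs larger
+// than N_READS * 1024 elements are first reduced to one partial value per
+// block in an intermediate array, which is then reduced again to a scalar.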
+void all_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  constexpr int N_READS = 4;
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto get_args = [](size_t size, int N) {
+    size_t reductions = size / N;
+    int threads = 1024;
+    int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
+    size_t reductions_per_block = std::max(
+        static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
+    size_t block_step = reductions_per_block * N_READS;
+
+    return std::make_tuple(blocks, threads, block_step);
+  };
+
+  int blocks, threads;
+  size_t block_step;
+  bool large = in.size() > N_READS * 1024;
+  array x = in;
+
+  // Large array so allocate an intermediate and accumulate there
+  if (large) {
+    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+    array intermediate({blocks}, out.dtype(), nullptr, {});
+    intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+    encoder.add_temporary(intermediate);
+    encoder.set_input_array(x);
+    encoder.set_output_array(intermediate);
+    encoder.launch_kernel([&](cudaStream_t stream) {
+      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+        MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+          auto kernel = cu::all_reduce<T, U, OP, N_READS>;
+          kernel<<<blocks, threads, 0, stream>>>(
+              x.data<T>(), intermediate.data<U>(), block_step, x.size());
+        });
+      });
+    });
+    x = intermediate;
+  }
+
+  // Final reduction
+  {
+    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+    encoder.set_input_array(x);
+    encoder.set_output_array(out);
+    encoder.launch_kernel([&](cudaStream_t stream) {
+      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+        MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+          auto kernel = cu::all_reduce<T, U, OP, N_READS>;
+          kernel<<<blocks, threads, 0, stream>>>(
+              x.data<T>(), out.data<U>(), block_step, x.size());
+        });
+      });
+    });
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh
index a673e052e..07041efce 100644
--- a/mlx/backend/cuda/reduce/reduce.cuh
+++ b/mlx/backend/cuda/reduce/reduce.cuh
@@ -47,6 +47,12 @@ namespace mlx::core {
       throw std::invalid_argument("Unknown reduce type."); \
   }
 
+void all_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type);
+
 void segmented_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index 832787222..1637d66f6 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -6,17 +6,65 @@
 
 namespace mlx::core::cu {
 
+template <size_t N>
+struct uint_by_size;
+template <>
+struct uint_by_size<2> {
+  using type = uint16_t;
+};
+template <>
+struct uint_by_size<4> {
+  using type = uint32_t;
+};
+template <>
+struct uint_by_size<8> {
+  using type = unsigned long long int;
+};
+
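+// Generic lock-free read-modify-write: reinterpret the destination as an
+// unsigned integer of the same width (for 1-byte types, the aligned 16-bit
+// word containing it), apply Op to the current value, and retry with
+// atomicCAS until no other thread has updated the word in between.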
+template <typename Op, typename T>
+__device__ void atomic_reduce(T* x, T y) {
+  if constexpr (sizeof(T) == 1) {
+    using U = uint16_t;
+    U* x_int = (U*)((char*)x - ((size_t)x % 2));
+    int shift = ((char*)x - (char*)x_int) * 8;
+    int mask = 0xff << shift;
+    U old_val, new_val;
+    do {
+      old_val = *x_int;
+      T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
+      new_val = (old_val & ~mask) | (result << shift);
+    } while (atomicCAS(x_int, old_val, new_val) != old_val);
+  } else {
+    using U = typename uint_by_size<sizeof(T)>::type;
+    U* x_int = (U*)(x);
+    U old_val, new_val;
+    do {
+      old_val = *x_int;
+      T result = Op{}(*((T*)&old_val), y);
+      new_val = *((U*)&result);
+    } while (atomicCAS(x_int, old_val, new_val) != old_val);
+  }
+}
+
 // Reduce ops.
 struct And {
   __device__ bool operator()(bool a, bool b) {
     return a && b;
   }
+
+  __device__ void atomic_update(bool* x, bool y) {
+    atomic_reduce<And>(x, y);
+  }
 };
 
 struct Or {
   __device__ bool operator()(bool a, bool b) {
     return a || b;
   }
+
+  __device__ void atomic_update(bool* x, bool y) {
+    atomic_reduce<Or>(x, y);
+  }
 };
 
 struct Sum {
@@ -24,6 +72,23 @@ struct Sum {
   __device__ T operator()(T a, T b) {
     return a + b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<Sum>(x, y);
+  }
+
+  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
+    atomicAdd(x, y);
+  }
+
+  __device__ void atomic_update(int* x, int y) {
+    atomicAdd(x, y);
+  }
+
+  __device__ void atomic_update(float* x, float y) {
+    atomicAdd(x, y);
+  }
 };
 
 struct Prod {
@@ -31,6 +96,11 @@ struct Prod {
   __device__ T operator()(T a, T b) {
     return a * b;
  }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<Prod>(x, y);
+  }
 };
 
 struct Min {
@@ -38,6 +108,11 @@ struct Min {
   __device__ T operator()(T a, T b) {
     return a < b ? a : b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<Min>(x, y);
+  }
 };
 
 struct Max {
@@ -45,6 +120,11 @@ struct Max {
   __device__ T operator()(T a, T b) {
     return a > b ? a : b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<Max>(x, y);
+  }
 };
 
 // Traits to get the result type of reduce op.
@@ -120,7 +200,7 @@ template <typename T>
 struct ReduceInit<Prod, T> {
   static constexpr __host__ __device__ auto value() {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return T{1, 1};
+      return T{1, 0};
     } else {
       return typename ReduceResult<Prod, T>::type{1};
     }