From b70a964cdea85cd9932d9b4475373fb9e4d2ca91 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 18 Jun 2025 14:33:27 -0700
Subject: [PATCH] Optimize all reduce a bit

---
 mlx/backend/cuda/reduce/all_reduce.cu  | 48 ++++++++++++++++----------
 mlx/backend/cuda/reduce/reduce_ops.cuh | 12 +++----
 2 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index 64a793f14..fd5ec256f 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -102,8 +102,20 @@ void all_reduce(
 
   auto get_args = [](size_t size, int N) {
     size_t reductions = size / N;
-    int threads = 1024;
-    int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
+    int threads = 512;
+    size_t full_blocks = (reductions + threads - 1) / threads;
+    int blocks;
+    if (full_blocks < 32) {
+      blocks = 1;
+    } else if (full_blocks < 128) {
+      blocks = 32;
+    } else if (full_blocks < 512) {
+      blocks = 128;
+    } else if (full_blocks < 1024) {
+      blocks = 512;
+    } else {
+      blocks = 1024;
+    }
     size_t reductions_per_block = std::max(
         static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
     size_t block_step = reductions_per_block * N_READS;
@@ -117,7 +129,8 @@ void all_reduce(
   array x = in;
 
   // Large array so allocate an intermediate and accumulate there
-  if (large) {
+  std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  if (blocks > 1) {
     std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
     array intermediate({blocks}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc(intermediate.nbytes()));
@@ -135,26 +148,25 @@ void all_reduce(
         });
       });
     });
+
+    // Set the input for the next step and recalculate the blocks
     x = intermediate;
+    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
   }
 
-  // Final reduction
-  {
-    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
-    encoder.set_input_array(x);
-    encoder.set_output_array(out);
-    encoder.launch_kernel([&](cudaStream_t stream) {
-      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
-        MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-          using T = cuda_type_t<CTYPE>;
-          using U = cu::ReduceResult<OP, T>::type;
-          auto kernel = cu::all_reduce<T, U, OP>;
-          kernel<<<blocks, threads, 0, stream>>>(
-              x.data<T>(), out.data<U>(), block_step, x.size());
-        });
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+        auto kernel = cu::all_reduce<T, U, OP>;
+        kernel<<<blocks, threads, 0, stream>>>(
+            x.data<T>(), out.data<U>(), block_step, x.size());
       });
     });
-  }
+  });
 }
 
 } // namespace mlx::core
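Note on the sizing change above: instead of always launching 1024 threads and up to 1024 blocks, the patched get_args uses 512 threads and steps the block count through 1, 32, 128, 512, 1024 as the number of reductions grows. Because small inputs now map to blocks == 1, the intermediate allocation and first kernel pass are skipped entirely, which is what the new `if (blocks > 1)` test buys over the old `large` flag. Below is a minimal host-side sketch of the heuristic; it mirrors the lambda in the patch but returns a std::tuple, and the N = 8 reads-per-thread value and the sizes in main() are assumptions for illustration only.

// Host-side sketch of the block-sizing staircase from the patch above.
// N plays the role of N_READS (elements read per thread per step); the
// value 8 passed in main() is an assumption, not taken from the patch.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <tuple>

std::tuple<int, int, size_t> get_args(size_t size, int N) {
  size_t reductions = size / N;
  int threads = 512;
  size_t full_blocks = (reductions + threads - 1) / threads;
  int blocks;
  if (full_blocks < 32) {
    blocks = 1; // small input: one block, no intermediate pass needed
  } else if (full_blocks < 128) {
    blocks = 32;
  } else if (full_blocks < 512) {
    blocks = 128;
  } else if (full_blocks < 1024) {
    blocks = 512;
  } else {
    blocks = 1024; // cap, as in the old code
  }
  // Each block owns a contiguous chunk of reductions_per_block * N
  // elements, so the grid exactly tiles the input.
  size_t reductions_per_block = std::max(
      static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
  size_t block_step = reductions_per_block * N;
  return {blocks, threads, block_step};
}

int main() {
  for (size_t size : {1UL << 12, 1UL << 20, 1UL << 28}) {
    auto [blocks, threads, block_step] = get_args(size, 8);
    std::printf(
        "size=%zu -> blocks=%d threads=%d block_step=%zu\n",
        size, blocks, threads, block_step);
  }
  return 0;
}

For example, size = 1 << 20 with N = 8 gives 131072 reductions, hence 256 full blocks; the staircase rounds that down to blocks = 128 with block_step = 1024 * 8 = 8192, so 128 blocks exactly tile the input in one pass over an intermediate of 128 partials.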
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index 1637d66f6..c19a801de 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -48,7 +48,7 @@ __device__ void atomic_reduce(T* x, T y) {
 
 // Reduce ops.
 struct And {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a && b;
   }
 
@@ -58,7 +58,7 @@ struct And {
 };
 
 struct Or {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a || b;
   }
 
@@ -69,7 +69,7 @@ struct Or {
 };
 
 struct Sum {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a + b;
   }
 
@@ -93,7 +93,7 @@ struct Sum {
 };
 
 struct Prod {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a * b;
   }
 
@@ -105,7 +105,7 @@ struct Prod {
 };
 
 struct Min {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a < b ? a : b;
   }
 
@@ -117,7 +117,7 @@ struct Min {
 };
 
 struct Max {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a > b ? a : b;
   }
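The reduce_ops.cuh change marks every functor's operator() as __device__ __forceinline__, so the comparison or arithmetic is guaranteed to inline into the reduction loop rather than risk an actual call in the hot path. Here is a self-contained sketch of that pattern: the Sum functor matches the patched header, but the kernel, the 0 init value, and the launch in main() are illustrative assumptions, not MLX code.

// Sketch of a functor with a force-inlined operator() driving a
// one-block sum reduction.
#include <cstdio>
#include <cuda_runtime.h>

struct Sum {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    return a + b;
  }
};

template <typename T, typename Op>
__global__ void reduce_one_block(const T* in, T* out, int n) {
  Op op;
  T acc = T(0); // 0 is Sum's identity; each op needs its own init in general
  // Strided accumulation: op() runs once per element, so __forceinline__
  // turns what could be a call into a bare add in the inner loop.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    acc = op(acc, in[i]);
  }
  // Reduce within each warp via shuffles, still through op().
  for (int offset = 16; offset > 0; offset /= 2) {
    acc = op(acc, __shfl_down_sync(0xffffffff, acc, offset));
  }
  __shared__ T partials[32];
  if (threadIdx.x % 32 == 0) {
    partials[threadIdx.x / 32] = acc;
  }
  __syncthreads();
  // Thread 0 folds the per-warp partials into the final value.
  if (threadIdx.x == 0) {
    T total = T(0);
    for (int w = 0; w < (blockDim.x + 31) / 32; ++w) {
      total = op(total, partials[w]);
    }
    *out = total;
  }
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) {
    in[i] = 1.0f;
  }
  reduce_one_block<float, Sum><<<1, 512>>>(in, out, n);
  cudaDeviceSynchronize();
  std::printf("sum = %f (expected %d)\n", *out, n);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Compilers usually inline one-liners like these on their own, but __forceinline__ removes the gamble inside deeply nested template code, which fits the spirit of this patch: shaving fixed overhead out of an already hot loop.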