Optimize all reduce a bit

This commit is contained in:
Angelos Katharopoulos 2025-06-18 14:33:27 -07:00
parent 9cf7ef1068
commit b70a964cde
2 changed files with 36 additions and 24 deletions

View File

@ -102,8 +102,20 @@ void all_reduce(
auto get_args = [](size_t size, int N) { auto get_args = [](size_t size, int N) {
size_t reductions = size / N; size_t reductions = size / N;
int threads = 1024; int threads = 512;
int blocks = std::min(1024UL, (reductions + threads - 1) / threads); size_t full_blocks = (reductions + threads - 1) / threads;
int blocks;
if (full_blocks < 32) {
blocks = 1;
} else if (full_blocks < 128) {
blocks = 32;
} else if (full_blocks < 512) {
blocks = 128;
} else if (full_blocks < 1024) {
blocks = 512;
} else {
blocks = 1024;
}
size_t reductions_per_block = std::max( size_t reductions_per_block = std::max(
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks); static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
size_t block_step = reductions_per_block * N_READS; size_t block_step = reductions_per_block * N_READS;
@ -117,7 +129,8 @@ void all_reduce(
array x = in; array x = in;
// Large array so allocate an intermediate and accumulate there // Large array so allocate an intermediate and accumulate there
if (large) { std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
if (blocks > 1) {
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS); std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
array intermediate({blocks}, out.dtype(), nullptr, {}); array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes())); intermediate.set_data(allocator::malloc(intermediate.nbytes()));
@ -135,12 +148,12 @@ void all_reduce(
}); });
}); });
}); });
// Set the input for the next step and recalculate the blocks
x = intermediate; x = intermediate;
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
} }
// Final reduction
{
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
@ -154,7 +167,6 @@ void all_reduce(
}); });
}); });
}); });
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -48,7 +48,7 @@ __device__ void atomic_reduce(T* x, T y) {
// Reduce ops. // Reduce ops.
struct And { struct And {
__device__ bool operator()(bool a, bool b) { __device__ __forceinline__ bool operator()(bool a, bool b) {
return a && b; return a && b;
} }
@ -58,7 +58,7 @@ struct And {
}; };
struct Or { struct Or {
__device__ bool operator()(bool a, bool b) { __device__ __forceinline__ bool operator()(bool a, bool b) {
return a || b; return a || b;
} }
@ -69,7 +69,7 @@ struct Or {
struct Sum { struct Sum {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a + b; return a + b;
} }
@ -93,7 +93,7 @@ struct Sum {
struct Prod { struct Prod {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a * b; return a * b;
} }
@ -105,7 +105,7 @@ struct Prod {
struct Min { struct Min {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a < b ? a : b; return a < b ? a : b;
} }
@ -117,7 +117,7 @@ struct Min {
struct Max { struct Max {
template <typename T> template <typename T>
__device__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
return a > b ? a : b; return a > b ? a : b;
} }