Optimize all reduce a bit

This commit is contained in:
Angelos Katharopoulos 2025-06-18 14:33:27 -07:00
parent 9cf7ef1068
commit b70a964cde
2 changed files with 36 additions and 24 deletions

View File

@ -102,8 +102,20 @@ void all_reduce(
auto get_args = [](size_t size, int N) {
size_t reductions = size / N;
int threads = 1024;
int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
int threads = 512;
size_t full_blocks = (reductions + threads - 1) / threads;
int blocks;
if (full_blocks < 32) {
blocks = 1;
} else if (full_blocks < 128) {
blocks = 32;
} else if (full_blocks < 512) {
blocks = 128;
} else if (full_blocks < 1024) {
blocks = 512;
} else {
blocks = 1024;
}
size_t reductions_per_block = std::max(
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
size_t block_step = reductions_per_block * N_READS;
@ -117,7 +129,8 @@ void all_reduce(
array x = in;
// Large array so allocate an intermediate and accumulate there
if (large) {
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
if (blocks > 1) {
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
@ -135,26 +148,25 @@ void all_reduce(
});
});
});
// Set the input for the next step and recalculate the blocks
x = intermediate;
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
}
// Final reduction
{
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), out.data<U>(), block_step, x.size());
});
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), out.data<U>(), block_step, x.size());
});
});
}
});
}
} // namespace mlx::core

View File

@ -48,7 +48,7 @@ __device__ void atomic_reduce(T* x, T y) {
// Reduce ops.
struct And {
__device__ bool operator()(bool a, bool b) {
__device__ __forceinline__ bool operator()(bool a, bool b) {
return a && b;
}
@ -58,7 +58,7 @@ struct And {
};
struct Or {
__device__ bool operator()(bool a, bool b) {
__device__ __forceinline__ bool operator()(bool a, bool b) {
return a || b;
}
@ -69,7 +69,7 @@ struct Or {
struct Sum {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a + b;
}
@ -93,7 +93,7 @@ struct Sum {
struct Prod {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a * b;
}
@ -105,7 +105,7 @@ struct Prod {
struct Min {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a < b ? a : b;
}
@ -117,7 +117,7 @@ struct Min {
struct Max {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a > b ? a : b;
}