mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Optimize all reduce a bit

parent 9cf7ef1068
commit b70a964cde
@@ -102,8 +102,20 @@ void all_reduce(
   auto get_args = [](size_t size, int N) {
     size_t reductions = size / N;
-    int threads = 1024;
-    int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
+    int threads = 512;
+    size_t full_blocks = (reductions + threads - 1) / threads;
+    int blocks;
+    if (full_blocks < 32) {
+      blocks = 1;
+    } else if (full_blocks < 128) {
+      blocks = 32;
+    } else if (full_blocks < 512) {
+      blocks = 128;
+    } else if (full_blocks < 1024) {
+      blocks = 512;
+    } else {
+      blocks = 1024;
+    }
     size_t reductions_per_block = std::max(
         static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
     size_t block_step = reductions_per_block * N_READS;
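
Before this change, get_args always used 1024 threads per block and simply capped the grid at 1024 blocks. The new version halves the block size to 512 threads and snaps the grid to a coarse staircase (1, 32, 128, 512, or 1024 blocks), so small and mid-sized reductions launch fewer, fuller blocks, and anything small enough for one block can skip the intermediate pass entirely. A minimal standalone sketch of the new heuristic, runnable on its own (N_READS = 8 and every name outside the hunk above are assumptions, not part of this commit):

#include <algorithm>
#include <cstdio>

// Standalone re-implementation of the new heuristic, for illustration.
// N_READS = 8 is an assumption about the surrounding file.
constexpr int N_READS = 8;

struct LaunchArgs {
  int blocks;
  int threads;
  size_t block_step; // elements consumed per block in one pass
};

LaunchArgs get_args(size_t size, int N) {
  size_t reductions = size / N;
  int threads = 512;
  size_t full_blocks = (reductions + threads - 1) / threads;
  int blocks;
  if (full_blocks < 32) {
    blocks = 1; // small input: one block, so no intermediate pass
  } else if (full_blocks < 128) {
    blocks = 32;
  } else if (full_blocks < 512) {
    blocks = 128;
  } else if (full_blocks < 1024) {
    blocks = 512;
  } else {
    blocks = 1024; // same cap as the old std::min(1024UL, ...)
  }
  size_t reductions_per_block = std::max(
      static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
  return {blocks, threads, reductions_per_block * size_t(N)};
}

int main() {
  for (size_t size : {size_t(1) << 12, size_t(1) << 20, size_t(1) << 26}) {
    LaunchArgs a = get_args(size, N_READS);
    std::printf("size=%zu -> blocks=%d threads=%d block_step=%zu\n",
                size, a.blocks, a.threads, a.block_step);
  }
}

A 2^20-element input, for example, now gets 128 blocks of 512 threads with a block_step of 8192, while anything below 32 * 512 * 8 = 131072 elements lands in the full_blocks < 32 bucket and runs as a single block.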
@@ -117,7 +129,8 @@ void all_reduce(
   array x = in;
 
   // Large array so allocate an intermediate and accumulate there
-  if (large) {
-    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  if (blocks > 1) {
     array intermediate({blocks}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc(intermediate.nbytes()));
@@ -135,26 +148,25 @@ void all_reduce(
         });
       });
     });
 
     // Set the input for the next step and recalculate the blocks
     x = intermediate;
+    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
   }
 
   // Final reduction
   {
-    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
     encoder.set_input_array(x);
     encoder.set_output_array(out);
     encoder.launch_kernel([&](cudaStream_t stream) {
       MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
         MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
           using T = cuda_type_t<CTYPE>;
           using U = cu::ReduceResult<OP, T>::type;
           auto kernel = cu::all_reduce<T, U, OP, N_READS>;
           kernel<<<blocks, threads, 0, stream>>>(
               x.data<T>(), out.data<U>(), block_step, x.size());
         });
       });
     });
   }
 }
 
 } // namespace mlx::core
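
Taken together, these two hunks restructure the driver: the launch configuration is computed once up front, the two-step path is taken only when that configuration needs more than one block, the configuration is recalculated when the intermediate becomes the new input, and the final launch reuses whatever was computed last instead of calling get_args again. A CPU-only sketch of that control flow, with get_args condensed from the sketch above (reduce_pass stands in for launching cu::all_reduce through the command encoder; none of these names are MLX API):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <tuple>
#include <vector>

std::tuple<int, int, size_t> get_args(size_t size, int N) {
  size_t reductions = size / N;
  int threads = 512;
  size_t full_blocks = (reductions + threads - 1) / threads;
  int blocks = full_blocks < 32 ? 1
      : full_blocks < 128      ? 32
      : full_blocks < 512      ? 128
      : full_blocks < 1024     ? 512
                               : 1024;
  size_t per_block =
      std::max<size_t>(threads, (reductions + blocks - 1) / blocks);
  return {blocks, threads, per_block * N};
}

// One "kernel launch": block b sums its block_step-sized slice.
std::vector<float> reduce_pass(
    const std::vector<float>& x, int blocks, size_t step) {
  std::vector<float> out(blocks, 0.0f);
  for (int b = 0; b < blocks; ++b) {
    size_t lo = b * step;
    size_t hi = std::min(x.size(), lo + step);
    if (lo < hi)
      out[b] = std::accumulate(x.begin() + lo, x.begin() + hi, 0.0f);
  }
  return out;
}

int main() {
  std::vector<float> x(size_t(1) << 20, 1.0f);

  // Compute the launch configuration once, up front.
  auto [blocks, threads, step] = get_args(x.size(), 8);

  // Two-step path only when the first pass needs more than one block:
  // reduce into an intermediate, then recalculate the blocks for it.
  if (blocks > 1) {
    x = reduce_pass(x, blocks, step);
    std::tie(blocks, threads, step) = get_args(x.size(), 8);
  }

  // Final reduction reuses the configuration computed above.
  x = reduce_pass(x, blocks, step);
  std::printf("sum = %.0f (expected %zu)\n", x[0], size_t(1) << 20);
}

The shape of the logic is the point: one up-front get_args call, at most one intermediate pass, and no recomputation before the final launch.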
@@ -48,7 +48,7 @@ __device__ void atomic_reduce(T* x, T y) {
 
 // Reduce ops.
 struct And {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a && b;
   }
 
@@ -58,7 +58,7 @@ struct And {
 };
 
 struct Or {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a || b;
   }
 
@@ -69,7 +69,7 @@ struct Or {
 
 struct Sum {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a + b;
   }
 
@@ -93,7 +93,7 @@ struct Sum {
 
 struct Prod {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a * b;
   }
 
@@ -105,7 +105,7 @@ struct Prod {
 
 struct Min {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a < b ? a : b;
   }
 
@@ -117,7 +117,7 @@ struct Min {
 
 struct Max {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a > b ? a : b;
   }
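
The remaining hunks touch the reduce-op functors: every operator() gains __forceinline__, so the comparison or arithmetic is guaranteed to be inlined into the innermost loop of the reduction rather than ever compiling to a device function call. A toy kernel showing the pattern (everything below is an illustrative stand-in, not the MLX all_reduce kernel):

#include <cstdio>

// A reduce-op functor whose operator() is __device__ __forceinline__,
// applied in the innermost loop of a block-strided reduction.
struct Sum {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    return a + b;
  }
};

template <typename T, typename Op, int N_READS>
__global__ void toy_reduce(const T* in, T* out, size_t size) {
  Op op;
  T acc = T(0);
  size_t base = (size_t)(blockIdx.x * blockDim.x + threadIdx.x) * N_READS;
  // With __forceinline__ each op() call below compiles to a bare add;
  // there is never a device function call in this loop.
  for (int i = 0; i < N_READS; ++i) {
    if (base + i < size) {
      acc = op(acc, in[base + i]);
    }
  }
  atomicAdd(out, acc);
}

int main() {
  constexpr int N = 1 << 16;
  float *in, *out;
  cudaMallocManaged(&in, N * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < N; ++i) in[i] = 1.0f;
  *out = 0.0f;

  constexpr int threads = 256, n_reads = 8;
  int blocks = (N + threads * n_reads - 1) / (threads * n_reads);
  toy_reduce<float, Sum, n_reads><<<blocks, threads>>>(in, out, N);
  cudaDeviceSynchronize();

  std::printf("sum = %.0f (expected %d)\n", *out, N);
  cudaFree(in);
  cudaFree(out);
}

nvcc typically inlines one-line __device__ functions on its own, so the practical effect is to pin that behaviour down (for example at low optimization levels, or where the inliner's heuristics would otherwise decline) rather than to change the common case.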