mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Optimize all reduce a bit
This commit is contained in:
parent
9cf7ef1068
commit
b70a964cde
@ -102,8 +102,20 @@ void all_reduce(
|
|||||||
|
|
||||||
auto get_args = [](size_t size, int N) {
|
auto get_args = [](size_t size, int N) {
|
||||||
size_t reductions = size / N;
|
size_t reductions = size / N;
|
||||||
int threads = 1024;
|
int threads = 512;
|
||||||
int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
|
size_t full_blocks = (reductions + threads - 1) / threads;
|
||||||
|
int blocks;
|
||||||
|
if (full_blocks < 32) {
|
||||||
|
blocks = 1;
|
||||||
|
} else if (full_blocks < 128) {
|
||||||
|
blocks = 32;
|
||||||
|
} else if (full_blocks < 512) {
|
||||||
|
blocks = 128;
|
||||||
|
} else if (full_blocks < 1024) {
|
||||||
|
blocks = 512;
|
||||||
|
} else {
|
||||||
|
blocks = 1024;
|
||||||
|
}
|
||||||
size_t reductions_per_block = std::max(
|
size_t reductions_per_block = std::max(
|
||||||
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
|
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
|
||||||
size_t block_step = reductions_per_block * N_READS;
|
size_t block_step = reductions_per_block * N_READS;
|
||||||
@ -117,7 +129,8 @@ void all_reduce(
|
|||||||
array x = in;
|
array x = in;
|
||||||
|
|
||||||
// Large array so allocate an intermediate and accumulate there
|
// Large array so allocate an intermediate and accumulate there
|
||||||
if (large) {
|
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||||
|
if (blocks > 1) {
|
||||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||||
@ -135,12 +148,12 @@ void all_reduce(
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Set the input for the next step and recalculate the blocks
|
||||||
x = intermediate;
|
x = intermediate;
|
||||||
|
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final reduction
|
|
||||||
{
|
|
||||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
|
||||||
encoder.set_input_array(x);
|
encoder.set_input_array(x);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
@ -154,7 +167,6 @@ void all_reduce(
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -48,7 +48,7 @@ __device__ void atomic_reduce(T* x, T y) {
|
|||||||
|
|
||||||
// Reduce ops.
|
// Reduce ops.
|
||||||
struct And {
|
struct And {
|
||||||
__device__ bool operator()(bool a, bool b) {
|
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
||||||
return a && b;
|
return a && b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ struct And {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Or {
|
struct Or {
|
||||||
__device__ bool operator()(bool a, bool b) {
|
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
||||||
return a || b;
|
return a || b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ struct Or {
|
|||||||
|
|
||||||
struct Sum {
|
struct Sum {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T a, T b) {
|
__device__ __forceinline__ T operator()(T a, T b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ struct Sum {
|
|||||||
|
|
||||||
struct Prod {
|
struct Prod {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T a, T b) {
|
__device__ __forceinline__ T operator()(T a, T b) {
|
||||||
return a * b;
|
return a * b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ struct Prod {
|
|||||||
|
|
||||||
struct Min {
|
struct Min {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T a, T b) {
|
__device__ __forceinline__ T operator()(T a, T b) {
|
||||||
return a < b ? a : b;
|
return a < b ? a : b;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +117,7 @@ struct Min {
|
|||||||
|
|
||||||
struct Max {
|
struct Max {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T a, T b) {
|
__device__ __forceinline__ T operator()(T a, T b) {
|
||||||
return a > b ? a : b;
|
return a > b ? a : b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user