mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
safely divide for 0 size inputs (#388)
This commit is contained in:
parent
b34bf5d52b
commit
c6d2878c1a
@ -21,6 +21,10 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Ceiling division of n by m that tolerates degenerate inputs.
//
// Returns the smallest q such that q * m >= n. When m is zero the result
// is defined to be 0 rather than performing a division by zero.
//
// The quotient-plus-remainder form is used instead of (n + m - 1) / m so
// that the intermediate sum cannot wrap around when n is close to the
// largest representable size_t value.
inline auto safe_divup(size_t n, size_t m) {
  if (m == 0) {
    return size_t{0};
  }
  // Add one extra group only when n is not an exact multiple of m.
  return n / m + (n % m != 0 ? size_t{1} : size_t{0});
}
|
||||
|
||||
// All Reduce
|
||||
void all_reduce_dispatch(
|
||||
const array& in,
|
||||
@ -52,8 +56,7 @@ void all_reduce_dispatch(
|
||||
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
|
||||
|
||||
// If the number of thread groups needed exceeds 1024, we reuse thread groups
|
||||
uint n_thread_groups =
|
||||
(mod_in_size + thread_group_size - 1) / thread_group_size;
|
||||
uint n_thread_groups = safe_divup(mod_in_size, thread_group_size);
|
||||
n_thread_groups = std::min(n_thread_groups, 1024u);
|
||||
uint nthreads = n_thread_groups * thread_group_size;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user