safely divide for 0 size inputs (#388)

This commit is contained in:
Awni Hannun 2024-01-07 00:19:54 -08:00 committed by GitHub
parent b34bf5d52b
commit c6d2878c1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -21,6 +21,10 @@ namespace mlx::core {
namespace { namespace {
// Ceiling division of n by m that tolerates a zero divisor.
//
// Returns the smallest count k such that k * m >= n, or 0 when m is 0
// (instead of dividing by zero). The result type is deduced as size_t.
inline auto safe_divup(size_t n, size_t m) {
  if (m == 0) {
    return size_t{0};
  }
  // Quotient plus a remainder carry rather than (n + m - 1) / m: the latter
  // wraps around when n + m - 1 exceeds the range of size_t.
  return n / m + (n % m != 0 ? size_t{1} : size_t{0});
}
// All Reduce // All Reduce
void all_reduce_dispatch( void all_reduce_dispatch(
const array& in, const array& in,
@ -52,8 +56,7 @@ void all_reduce_dispatch(
mod_in_size > thread_group_size ? thread_group_size : mod_in_size; mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
// If the number of thread groups needed exceeds 1024, we reuse thread groups // If the number of thread groups needed exceeds 1024, we reuse thread groups
uint n_thread_groups = uint n_thread_groups = safe_divup(mod_in_size, thread_group_size);
(mod_in_size + thread_group_size - 1) / thread_group_size;
n_thread_groups = std::min(n_thread_groups, 1024u); n_thread_groups = std::min(n_thread_groups, 1024u);
uint nthreads = n_thread_groups * thread_group_size; uint nthreads = n_thread_groups * thread_group_size;