From c6d2878c1a28f0f6e72e13b9d0c0b1064e95b521 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 7 Jan 2024 00:19:54 -0800 Subject: [PATCH] safely divide for 0 size inputs (#388) --- mlx/backend/metal/reduce.cpp | 7 +++++-- tests/ops_tests.cpp | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 9da5c79bf..012038178 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -21,6 +21,10 @@ namespace mlx::core { namespace { +inline auto safe_divup(size_t n, size_t m) { + return m == 0 ? 0 : (n + m - 1) / m; +} + // All Reduce void all_reduce_dispatch( const array& in, @@ -52,8 +56,7 @@ void all_reduce_dispatch( mod_in_size > thread_group_size ? thread_group_size : mod_in_size; // If the number of thread groups needed exceeds 1024, we reuse threads groups - uint n_thread_groups = - (mod_in_size + thread_group_size - 1) / thread_group_size; + uint n_thread_groups = safe_divup(mod_in_size, thread_group_size); n_thread_groups = std::min(n_thread_groups, 1024u); uint nthreads = n_thread_groups * thread_group_size; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index ad767f525..0a1dea0ad 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2314,4 +2314,4 @@ TEST_CASE("tensordot") { 65370.}, {3, 6}); CHECK(array_equal(z, expected).item()); -} \ No newline at end of file +}