mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	safely divide for 0 size inputs (#388)
This commit is contained in:
		| @@ -21,6 +21,10 @@ namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
// Ceiling division that tolerates a zero divisor.
//
// Returns the smallest count of size-`m` chunks needed to cover `n` items,
// i.e. ceil(n / m). When `m` is 0 the result is 0 instead of dividing by
// zero, so callers handling zero-size inputs get an empty launch size.
//
// The quotient/remainder form is used instead of `(n + m - 1) / m` because
// the latter wraps around in unsigned arithmetic when `n + m - 1` exceeds the
// range of size_t, yielding a result that is too small.
inline constexpr size_t safe_divup(size_t n, size_t m) noexcept {
  if (m == 0) {
    return 0;
  }
  // Add one extra chunk only when `n` is not an exact multiple of `m`.
  return n / m + (n % m != 0 ? 1 : 0);
}
|  | ||||
| // All Reduce | ||||
| void all_reduce_dispatch( | ||||
|     const array& in, | ||||
| @@ -52,8 +56,7 @@ void all_reduce_dispatch( | ||||
|       mod_in_size > thread_group_size ? thread_group_size : mod_in_size; | ||||
|  | ||||
|   // If the number of thread groups needed exceeds 1024, we reuse threads groups | ||||
|   uint n_thread_groups = | ||||
|       (mod_in_size + thread_group_size - 1) / thread_group_size; | ||||
|   uint n_thread_groups = safe_divup(mod_in_size, thread_group_size); | ||||
|   n_thread_groups = std::min(n_thread_groups, 1024u); | ||||
|   uint nthreads = n_thread_groups * thread_group_size; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun