mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Fix boolean all reduce bug (#1355)
This commit is contained in:

committed by
GitHub

parent
64bec4fad7
commit
8081df79be
@@ -308,7 +308,11 @@ void all_reduce_dispatch(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// 2nd pass
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
std::ostringstream kname_2nd_pass;
|
||||
kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
|
||||
auto kernel_2nd_pass =
|
||||
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
|
||||
compute_encoder->setComputePipelineState(kernel_2nd_pass);
|
||||
size_t intermediate_size = n_rows;
|
||||
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
||||
|
Reference in New Issue
Block a user