Fix boolean all reduce bug (#1355)

This commit is contained in:
Angelos Katharopoulos
2024-08-24 10:09:32 -07:00
committed by GitHub
parent 64bec4fad7
commit 8081df79be
4 changed files with 14 additions and 3 deletions

View File

@@ -308,7 +308,11 @@ void all_reduce_dispatch(
compute_encoder.dispatchThreads(grid_dims, group_dims);
// 2nd pass
compute_encoder->setComputePipelineState(kernel);
std::ostringstream kname_2nd_pass;
kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass =
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass);
size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);