fix dispatch threads for a few kernels (#1594)

This commit is contained in:
Awni Hannun
2024-11-18 08:35:25 -08:00
committed by GitHub
parent 16ec0556a0
commit 6931f84412
4 changed files with 13 additions and 12 deletions

View File

@@ -437,8 +437,7 @@ void steel_matmul(
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
@@ -955,8 +954,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}