mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix dispatch threads for a few kernels (#1594)
This commit is contained in:
@@ -437,8 +437,7 @@ void steel_matmul(
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
|
||||
|
||||
auto group_dims = get_block_dims(N, M, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -955,8 +954,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
|
||||
|
||||
auto group_dims = get_block_dims(N, M, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user