mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders * fix race
This commit is contained in:
@@ -65,7 +65,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmv kernel
|
||||
@@ -92,7 +92,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmm_t kernel
|
||||
@@ -123,7 +123,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
} else {
|
||||
// Route to the qvm kernel
|
||||
@@ -150,7 +150,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmm_n kernel
|
||||
@@ -188,7 +188,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user