mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster small batch qmv (#1861)
* faster small batch qmv * swap batch and block dims for qvm and qmv regular
This commit is contained in:
@@ -167,7 +167,7 @@ void qvm_split_k(
|
||||
int bo = 64;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
|
||||
MTL::Size grid_dims = MTL::Size(B, O / bo, N);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
@@ -313,13 +313,13 @@ void qmm_op(
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
group_dims = MTL::Size(bd, 2, 1);
|
||||
grid_dims = MTL::Size(O / bo, B, N);
|
||||
grid_dims = MTL::Size(B, O / bo, N);
|
||||
} else if (B < 6) {
|
||||
name += "qmv";
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
group_dims = MTL::Size(bd, 2, 1);
|
||||
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
|
||||
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
|
||||
} else {
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
@@ -339,7 +339,7 @@ void qmm_op(
|
||||
int bo = 64;
|
||||
int bd = 32;
|
||||
group_dims = MTL::Size(bd, 2, 1);
|
||||
grid_dims = MTL::Size(O / bo, B, N);
|
||||
grid_dims = MTL::Size(B, O / bo, N);
|
||||
} else {
|
||||
name += "qmm_n";
|
||||
int wn = 2;
|
||||
|
||||
Reference in New Issue
Block a user