Faster small batch qmv (#1861)

* faster small batch qmv

* swap batch and block dims for qvm and qmv regular
This commit is contained in:
Awni Hannun
2025-02-12 22:02:36 -08:00
committed by GitHub
parent d274ae77f2
commit e425dc00c0
2 changed files with 15 additions and 15 deletions

View File

@@ -167,7 +167,7 @@ void qvm_split_k(
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
MTL::Size grid_dims = MTL::Size(B, O / bo, N);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
@@ -313,13 +313,13 @@ void qmm_op(
int bo = 8;
int bd = 32;
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
grid_dims = MTL::Size(B, O / bo, N);
} else if (B < 6) {
name += "qmv";
int bo = 8;
int bd = 32;
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
} else {
int wn = 2;
int wm = 2;
@@ -339,7 +339,7 @@ void qmm_op(
int bo = 64;
int bd = 32;
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
grid_dims = MTL::Size(B, O / bo, N);
} else {
name += "qmm_n";
int wn = 2;