Simplifying and improving qmm (#1030)

This commit is contained in:
Angelos Katharopoulos
2024-04-24 13:07:45 -07:00
committed by GitHub
parent ec8578d41a
commit 20a01bbd9f
2 changed files with 262 additions and 202 deletions

View File

@@ -110,7 +110,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int wm = 2;
int bm = 32;
int bn = 32;
-int bk = 64;
+int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
@@ -167,7 +167,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int wn = 2;
int wm = 2;
int bm = 32;
-int bn = 64;
+int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);