mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Simplifying and improving qmm (#1030)
This commit is contained in:
committed by
GitHub
parent
ec8578d41a
commit
20a01bbd9f
@@ -110,7 +110,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 32;
|
||||
int bk = 64;
|
||||
int bk = 32;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
@@ -167,7 +167,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 64;
|
||||
int bn = 32;
|
||||
int bk = 32;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
Reference in New Issue
Block a user