mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improve qvm speed (#1140)
This commit is contained in:
committed by
GitHub
parent
7e5674d8be
commit
da83f899bb
@@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = 8;
|
||||
int bo = 64;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
|
||||
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
@@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = 8;
|
||||
int bo = 64;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
|
||||
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
|
||||
Reference in New Issue
Block a user