Improve qvm speed (#1140)

This commit is contained in:
Angelos Katharopoulos
2024-05-20 09:20:44 -07:00
committed by GitHub
parent 7e5674d8be
commit da83f899bb
2 changed files with 62 additions and 36 deletions

View File

@@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
-int bo = 8;
+int bo = 64;
int bd = 32;
-MTL::Size group_dims = MTL::Size(bd, bo, 1);
-MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
+MTL::Size group_dims = MTL::Size(bd, 2, 1);
+MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
@@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
-int bo = 8;
+int bo = 64;
int bd = 32;
-MTL::Size group_dims = MTL::Size(bd, bo, 1);
-MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+MTL::Size group_dims = MTL::Size(bd, 2, 1);
+MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);