Fix nan and improve speed for qvm (#903)

This commit is contained in:
Angelos Katharopoulos
2024-03-26 10:41:45 -07:00
committed by GitHub
parent a3ee03da01
commit 9948eddf11
2 changed files with 67 additions and 61 deletions

View File

@@ -137,7 +137,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = std::min(32, O);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);