Quantized matmul fix (#677)

* Fix qmv for small or unaligned matrices

* Fix qmm
This commit is contained in:
Angelos Katharopoulos
2024-02-12 18:54:21 -08:00
committed by GitHub
parent 4cc70290f7
commit 40c108766b
3 changed files with 81 additions and 9 deletions

View File

@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);