Support for quantized matmul with w and w^T (#349)

* Add the Metal qvm implementation
* Add qmm_n
* Add the gradient with respect to the input for quantized_matmul (see the sketch below)
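A minimal sketch of what the new input gradient enables, assuming the `mx.quantize` helper and the `mx.grad` transform from the rest of MLX; the shapes and group size here are illustrative only:

```python
import mlx.core as mx

# Illustrative shapes: a (out_features, in_features) weight, quantized to 4 bits.
w = mx.random.normal((512, 256))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
x = mx.random.normal((1, 256))

def loss(x):
    # transpose=True multiplies by the transposed quantized weight, i.e. x @ w.T.
    y = mx.quantized_matmul(
        x, w_q, scales=scales, biases=biases,
        transpose=True, group_size=64, bits=4,
    )
    return y.sum()

# With this commit, differentiating through quantized_matmul with respect
# to the (non-quantized) input x is supported.
dx = mx.grad(loss)(x)
print(dx.shape)  # (1, 256)
```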
Angelos Katharopoulos authored 2024-01-03 14:22:36 -08:00, committed by GitHub
parent d7ac050f4b
commit e7f5059fe4
12 changed files with 718 additions and 193 deletions


@@ -81,9 +81,10 @@ class QuantizedLinear(Module):
     def __call__(self, x):
         x = mx.quantized_matmul(
             x,
-            self.weight.T,
+            self.weight,
             scales=self.scales,
             biases=self.biases,
+            transpose=True,
             group_size=self.group_size,
             bits=self.bits,
         )
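The hunk above switches QuantizedLinear to pass `self.weight` in its stored layout and set `transpose=True`, instead of passing `self.weight.T`. A minimal standalone sketch of that calling convention, assuming the `mx.quantize` helper from the rest of MLX and illustrative shapes:

```python
import mlx.core as mx

# Weight stored as (out_features, in_features), as in QuantizedLinear.
w = mx.random.normal((512, 256))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

x = mx.random.normal((1, 256))

# transpose=True computes x @ w.T from the packed weight, so the caller can
# pass the weight directly rather than its transpose.
y = mx.quantized_matmul(
    x,
    w_q,
    scales=scales,
    biases=biases,
    transpose=True,
    group_size=64,
    bits=4,
)
print(y.shape)  # (1, 512)
```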