No reshapes in quantized embedding (#1682)

* no reshapes in quantized embedding

* fix inadvertent cast

* add tol
Awni Hannun authored 2024-12-09 18:57:38 -08:00, committed by GitHub
parent 87d7a2520e
commit 29a620cab2
6 changed files with 26 additions and 12 deletions


@@ -98,16 +98,13 @@ class QuantizedEmbedding(Module):
         self.freeze()
 
     def __call__(self, x):
-        s = x.shape
-        x = x.flatten()
-        out = mx.dequantize(
+        return mx.dequantize(
             self["weight"][x],
             scales=self["scales"][x],
             biases=self["biases"][x],
             group_size=self.group_size,
             bits=self.bits,
         )
-        return out.reshape(*s, -1)
 
     def as_linear(self, x):
         """