mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
@@ -98,16 +98,13 @@ class QuantizedEmbedding(Module):
|
||||
self.freeze()
|
||||
|
||||
def __call__(self, x):
|
||||
s = x.shape
|
||||
x = x.flatten()
|
||||
out = mx.dequantize(
|
||||
return mx.dequantize(
|
||||
self["weight"][x],
|
||||
scales=self["scales"][x],
|
||||
biases=self["biases"][x],
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
)
|
||||
return out.reshape(*s, -1)
|
||||
|
||||
def as_linear(self, x):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user