fix dequantization (#693)

Awni Hannun 2024-04-19 10:46:59 -07:00 committed by GitHub
parent 2146bcd7ee
commit 574ad7f6fe

@@ -147,6 +147,19 @@ def dequantize(model: nn.Module) -> nn.Module:
             if bias:
                 linear.bias = module.bias
             de_quantize_layers.append((name, linear))
+        if isinstance(module, nn.QuantizedEmbedding):
+            weight = mx.dequantize(
+                module.weight,
+                module.scales,
+                module.biases,
+                module.group_size,
+                module.bits,
+            ).astype(mx.float16)
+            num_embeddings, dims = weight.shape
+            emb = nn.Embedding(num_embeddings, dims)
+            emb.weight = weight
+            de_quantize_layers.append((name, emb))
     if len(de_quantize_layers) > 0:
         model.update_modules(tree_unflatten(de_quantize_layers))
     return model
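
For context, before this change dequantize() only handled nn.QuantizedLinear, so quantized embedding tables were silently left in their packed form. The added branch restores them to plain nn.Embedding layers. Below is a minimal standalone sketch, not part of the commit, that mirrors the new branch on a single layer; the layer sizes are arbitrary, and it assumes MLX's nn.QuantizedEmbedding.from_embedding constructor for producing the quantized input.

# Minimal sketch (not from the commit): quantize an embedding table,
# then recover a float16 nn.Embedding the same way the added branch does.
# Sizes are arbitrary illustrative choices.
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(128, 64)  # 128 tokens, 64-dim embeddings
qemb = nn.QuantizedEmbedding.from_embedding(emb, group_size=64, bits=4)

# Mirror the new QuantizedEmbedding branch in dequantize():
weight = mx.dequantize(
    qemb.weight,
    qemb.scales,
    qemb.biases,
    qemb.group_size,
    qemb.bits,
).astype(mx.float16)

num_embeddings, dims = weight.shape
restored = nn.Embedding(num_embeddings, dims)
restored.weight = weight
print(weight.shape)  # (128, 64): back to a dense embedding table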