mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix dequantization (#693)
This commit is contained in:
parent
2146bcd7ee
commit
574ad7f6fe
@ -147,6 +147,19 @@ def dequantize(model: nn.Module) -> nn.Module:
|
|||||||
if bias:
|
if bias:
|
||||||
linear.bias = module.bias
|
linear.bias = module.bias
|
||||||
de_quantize_layers.append((name, linear))
|
de_quantize_layers.append((name, linear))
|
||||||
|
if isinstance(module, nn.QuantizedEmbedding):
|
||||||
|
weight = mx.dequantize(
|
||||||
|
module.weight,
|
||||||
|
module.scales,
|
||||||
|
module.biases,
|
||||||
|
module.group_size,
|
||||||
|
module.bits,
|
||||||
|
).astype(mx.float16)
|
||||||
|
num_embeddings, dims = weight.shape
|
||||||
|
emb = nn.Embedding(num_embeddings, dims)
|
||||||
|
emb.weight = weight
|
||||||
|
de_quantize_layers.append((name, emb))
|
||||||
|
|
||||||
if len(de_quantize_layers) > 0:
|
if len(de_quantize_layers) > 0:
|
||||||
model.update_modules(tree_unflatten(de_quantize_layers))
|
model.update_modules(tree_unflatten(de_quantize_layers))
|
||||||
return model
|
return model
|
||||||
|
Loading…
Reference in New Issue
Block a user