mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 14:08:07 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
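The headline change is the move from the old nn.QuantizedLinear.quantize_module classmethod to the general nn.quantize entry point, whose class_predicate hook is what makes quantized embeddings possible. A minimal sketch of the new call, assuming a recent mlx version; the toy model and parameters below are illustrative, not taken from this commit:

import mlx.nn as nn

# Illustrative toy model: an embedding followed by a projection.
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 1000))

# Old call (replaced in this commit): nn.QuantizedLinear.quantize_module
# only handled nn.Linear layers.

# New API: one entry point; the predicate receives (path, module) and
# decides per module, so nn.Embedding layers can now be quantized too.
nn.quantize(
    model,
    group_size=64,
    bits=4,
    class_predicate=lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)),
)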
@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
     config = get_config(metadata)
     model = Model(ModelArgs(**config))
     if quantization is not None:
-        # quantized the LM head?
-        qm = model if "lm_head.scales" in weights else model.model
-        nn.QuantizedLinear.quantize_module(
-            qm,
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
+            model,
             **quantization,
+            class_predicate=class_predicate,
         )
 
-        def dequantize(k):
-            weight = weights.pop(f"{k}.weight")
-            scales = weights.pop(f"{k}.scales")
-            biases = weights.pop(f"{k}.biases")
-            weights[f"{k}.weight"] = mx.dequantize(
-                weight, scales=scales, biases=biases, **quantization
-            )
-
-        # Dequantize embeddings
-        dequantize("model.embed_tokens")
-
     tokenizer = GGUFTokenizer(metadata)
     model.load_weights(list(weights.items()))
     return model, tokenizer
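Two details of the hunk are worth noting. The predicate checks f"{p}.scales" in weights, so only modules whose quantization scales actually exist in the GGUF checkpoint are swapped for quantized layers. And the removed dequantize helper is no longer needed: since embeddings can now stay quantized, the embedding's packed weight, scales, and biases load directly via model.load_weights instead of being expanded to a dense matrix first. A hedged sketch of the round trip the removed helper relied on; the shape and quantization parameters are illustrative:

import mlx.core as mx

w = mx.random.normal((1024, 64))

# mx.quantize packs the matrix into low-bit groups plus per-group
# scales and biases.
wq, scales, biases = mx.quantize(w, group_size=64, bits=4)

# mx.dequantize reconstructs a dense weight, which is what the removed
# helper did for "model.embed_tokens" before this commit.
w_hat = mx.dequantize(wq, scales=scales, biases=biases, group_size=64, bits=4)

print(mx.abs(w - w_hat).max())  # small quantization error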