Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
Author: Awni Hannun
Date: 2024-04-18 18:16:10 -07:00
Committed by: GitHub
Parent: f5f189e48a
Commit: 2146bcd7ee
28 changed files with 108 additions and 190 deletions

@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
     config = get_config(metadata)
     model = Model(ModelArgs(**config))
     if quantization is not None:
-        # quantized the LM head?
         qm = model if "lm_head.scales" in weights else model.model
-        nn.QuantizedLinear.quantize_module(
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
             qm,
             **quantization,
+            class_predicate=class_predicate,
         )
-
-        def dequantize(k):
-            weight = weights.pop(f"{k}.weight")
-            scales = weights.pop(f"{k}.scales")
-            biases = weights.pop(f"{k}.biases")
-            weights[f"{k}.weight"] = mx.dequantize(
-                weight, scales=scales, biases=biases, **quantization
-            )
-
-        # Dequantize embeddings
-        dequantize("model.embed_tokens")
-
     tokenizer = GGUFTokenizer(metadata)
     model.load_weights(list(weights.items()))
     return model, tokenizer
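
For context: the old path called nn.QuantizedLinear.quantize_module and then manually dequantized the embedding table, since embeddings could not be quantized; the model-level nn.quantize replacing it takes a class_predicate that quantizes a module only when matching scales exist in the checkpoint, and it covers nn.Embedding as well as nn.Linear. Below is a minimal sketch of the predicate-driven API, assuming an MLX version that ships nn.quantize and nn.QuantizedEmbedding (the Toy module and the weights key set are made up for illustration, not from this commit):

    import mlx.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(128, 64)  # quantizable
            self.proj = nn.Linear(64, 64)       # quantizable
            self.norm = nn.RMSNorm(64)          # the predicate skips this

    model = Toy()

    # Stand-in for the loaded checkpoint keys: only these modules shipped scales.
    weights = {"embed.scales", "proj.scales"}

    # Quantize a module only if it is a Linear/Embedding *and* the checkpoint
    # actually carries quantization scales for it.
    class_predicate = (
        lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
        and f"{p}.scales" in weights
    )
    nn.quantize(model, group_size=32, bits=4, class_predicate=class_predicate)

    print(type(model.embed).__name__)  # QuantizedEmbedding
    print(type(model.proj).__name__)   # QuantizedLinear
    print(type(model.norm).__name__)   # RMSNorm (left unquantized)

Because nn.QuantizedEmbedding performs quantized lookups directly, the dequantize("model.embed_tokens") workaround removed in the diff above is no longer needed.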