Fix quant in gguf (#698)

* fix quant in gguf

* fix whisper
This commit is contained in:
Awni Hannun
2024-04-19 20:07:11 -07:00
committed by GitHub
parent 574ad7f6fe
commit 6abdbe3be8
2 changed files with 7 additions and 3 deletions

View File

@@ -27,13 +27,17 @@ def load_model(
# NOTE(review): this span is the body of a git diff hunk whose +/- markers
# were stripped during extraction; it is NOT runnable Python on its own.
# It appears to interleave the OLD and NEW bodies of load_model (the
# function signature is outside this view, in the @@ header). The hedged
# notes below flag which lines look like removed vs. added — confirm
# against the real diff before acting on them.
model_args = whisper.ModelDimensions(**config)
weights = mx.load(str(model_path / "weights.npz"))
# presumably the '-' (pre-fix) unflatten call — it is duplicated near the
# end of the hunk, consistent with the fix moving it after quantization.
weights = tree_unflatten(list(weights.items()))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
# NOTE(review): likely the '-' (removed) line — the old unconditional quantize.
nn.quantize(model, **quantization)
# NOTE(review): lines below look like the '+' (added) fix: only quantize
# modules whose pre-quantized scales are present in the checkpoint.
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
# NOTE(review): likely the '+' (moved) unflatten — now after quantization,
# so the quantized parameter tree matches the checkpoint layout.
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model