Fix quant in gguf (#698)

* fix quant in gguf

* fix whisper
This commit is contained in:
Awni Hannun 2024-04-19 20:07:11 -07:00 committed by GitHub
parent 574ad7f6fe
commit 6abdbe3be8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 3 deletions

View File

@ -285,7 +285,7 @@ def load(gguf_file: str, repo: str = None):
and f"{p}.scales" in weights
)
nn.quantize(
qm,
model,
**quantization,
class_predicate=class_predicate,
)

View File

@ -27,13 +27,17 @@ def load_model(
model_args = whisper.ModelDimensions(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
nn.quantize(model, **quantization)
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model