Fix quant in gguf (#698)

* fix quant in gguf

* fix whisper
This commit is contained in:
Awni Hannun 2024-04-19 20:07:11 -07:00 committed by GitHub
parent 574ad7f6fe
commit 6abdbe3be8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 3 deletions

View File

@@ -285,7 +285,7 @@ def load(gguf_file: str, repo: str = None):
             and f"{p}.scales" in weights
         )
         nn.quantize(
-            qm,
+            model,
             **quantization,
             class_predicate=class_predicate,
         )

View File

@@ -27,13 +27,17 @@ def load_model(
     model_args = whisper.ModelDimensions(**config)
     weights = mx.load(str(model_path / "weights.npz"))
-    weights = tree_unflatten(list(weights.items()))
     model = whisper.Whisper(model_args, dtype)
     if quantization is not None:
-        nn.quantize(model, **quantization)
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(model, **quantization, class_predicate=class_predicate)
+    weights = tree_unflatten(list(weights.items()))
     model.update(weights)
     mx.eval(model.parameters())
     return model