From 6abdbe3be8d4814d894a6803864b4f6202546eb9 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 19 Apr 2024 20:07:11 -0700 Subject: [PATCH] Fix quant in gguf (#698) * fix quant in gguf * fix whisper --- llms/gguf_llm/models.py | 2 +- whisper/whisper/load_models.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py index b0d07558..cc9b3f0e 100644 --- a/llms/gguf_llm/models.py +++ b/llms/gguf_llm/models.py @@ -285,7 +285,7 @@ def load(gguf_file: str, repo: str = None): and f"{p}.scales" in weights ) nn.quantize( - qm, + model, **quantization, class_predicate=class_predicate, ) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 2b7efaf0..6705385d 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -27,13 +27,17 @@ def load_model( model_args = whisper.ModelDimensions(**config) weights = mx.load(str(model_path / "weights.npz")) - weights = tree_unflatten(list(weights.items())) model = whisper.Whisper(model_args, dtype) if quantization is not None: - nn.quantize(model, **quantization) + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize(model, **quantization, class_predicate=class_predicate) + weights = tree_unflatten(list(weights.items())) model.update(weights) mx.eval(model.parameters()) return model