diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py index b0d07558..cc9b3f0e 100644 --- a/llms/gguf_llm/models.py +++ b/llms/gguf_llm/models.py @@ -285,7 +285,7 @@ def load(gguf_file: str, repo: str = None): and f"{p}.scales" in weights ) nn.quantize( - qm, + model, **quantization, class_predicate=class_predicate, ) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 2b7efaf0..6705385d 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -27,13 +27,17 @@ def load_model( model_args = whisper.ModelDimensions(**config) weights = mx.load(str(model_path / "weights.npz")) - weights = tree_unflatten(list(weights.items())) model = whisper.Whisper(model_args, dtype) if quantization is not None: - nn.quantize(model, **quantization) + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize(model, **quantization, class_predicate=class_predicate) + weights = tree_unflatten(list(weights.items())) model.update(weights) mx.eval(model.parameters()) return model