diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index cee5676b..6ab3edb7 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -458,13 +458,14 @@ def load_model(
         weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
+
         def class_predicate(p, m):
             # Handle custom per layer quantizations
             if p in config["quantization"]:
                 return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
+            # Handle legacy models which may not have everything quantized
             return f"{p}.scales" in weights
 
         nn.quantize(