From 5fc8668a539e47318e2e797ffcc313943cca262e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 25 Jan 2024 02:44:05 +1100
Subject: [PATCH] fix(mlx-lm): handle legacy quant models (#369)

---
 llms/mlx_lm/utils.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4a53ee9d..ab5b99af 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -226,11 +226,25 @@ def load_model(model_path: Path) -> nn.Module:
     model = model_class(model_args)
 
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
-            model,
-            **quantization,
-            linear_class_predicate=linear_class_predicate,
-        )
+        # for legacy models that don't have lm_head quant due to non-32 dims
+        if "lm_head.scales" not in weights.keys():
+            vocab_size = config["vocab_size"]
+            extended_linear_class_predicate = (
+                lambda layer: linear_class_predicate(layer)
+                and layer.weight.shape[0] != vocab_size
+            )
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=extended_linear_class_predicate,
+            )
+        # for models that have lm_head quant
+        else:
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=linear_class_predicate,
+            )
 
     model.load_weights(list(weights.items()))
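
For context, here is a minimal standalone sketch (not part of the patch, and not the real mlx / mlx-lm classes) of what the extended predicate does: when the checkpoint has no "lm_head.scales" entry, the output projection was left unquantized by the legacy conversion, so the loader must skip it when rebuilding QuantizedLinear layers. The make_linear helper, the base predicate's divisible-by-32 check, and the 4096/32000 sizes below are illustrative assumptions only.

# Standalone sketch -- stand-in objects, not the mlx or mlx-lm API.
from types import SimpleNamespace


def make_linear(output_dims, input_dims):
    # Stand-in for nn.Linear: the predicates below only inspect .weight.shape.
    return SimpleNamespace(weight=SimpleNamespace(shape=(output_dims, input_dims)))


# Stand-in for the base predicate (assumed here: quantize layers whose
# output dimension is compatible with the quantization group size).
def linear_class_predicate(layer):
    return layer.weight.shape[0] % 32 == 0


vocab_size = 32000  # hypothetical config["vocab_size"]

# The patch's extended predicate: additionally skip the lm_head, identified
# by an output dimension equal to vocab_size, because the legacy checkpoint
# stored that layer in full precision (no "lm_head.scales" in the weights).
extended_linear_class_predicate = (
    lambda layer: linear_class_predicate(layer)
    and layer.weight.shape[0] != vocab_size
)

attn_proj = make_linear(4096, 4096)
lm_head = make_linear(vocab_size, 4096)

print(extended_linear_class_predicate(attn_proj))  # True  -> quantize
print(extended_linear_class_predicate(lm_head))    # False -> keep as float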