diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4a53ee9d..ab5b99af 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -226,11 +226,25 @@ def load_model(model_path: Path) -> nn.Module:
     model = model_class(model_args)
 
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
-            model,
-            **quantization,
-            linear_class_predicate=linear_class_predicate,
-        )
+        # for legacy models that don't have lm_head quant due to non-32 dims
+        if "lm_head.scales" not in weights.keys():
+            vocab_size = config["vocab_size"]
+            extended_linear_class_predicate = (
+                lambda layer: linear_class_predicate(layer)
+                and layer.weight.shape[0] != vocab_size
+            )
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=extended_linear_class_predicate,
+            )
+        # for models that have lm_head quant
+        else:
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=linear_class_predicate,
+            )
 
     model.load_weights(list(weights.items()))
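
For context, a minimal standalone sketch of what the extended predicate does. The base linear_class_predicate and the layer dimensions below are illustrative assumptions, not the exact values in mlx_lm: the point is that the extended predicate rejects any Linear whose output dimension equals vocab_size, so the lm_head of a legacy checkpoint (one with no "lm_head.scales" in its weights) is left unquantized while every other linear layer is quantized as before.

    # Sketch only: the real predicate and config live in llms/mlx_lm/utils.py;
    # the base predicate and dimensions here are simplified assumptions.
    import mlx.nn as nn

    # base predicate: quantize every nn.Linear (simplified assumption)
    linear_class_predicate = lambda m: isinstance(m, nn.Linear)

    vocab_size = 32000  # stand-in for config["vocab_size"]

    # reject any linear whose output dim matches the vocab size, i.e. the
    # lm_head; in MLX, nn.Linear.weight has shape (output_dims, input_dims)
    extended_linear_class_predicate = (
        lambda layer: linear_class_predicate(layer)
        and layer.weight.shape[0] != vocab_size
    )

    lm_head = nn.Linear(4096, vocab_size)  # legacy lm_head: stays unquantized
    q_proj = nn.Linear(4096, 4096)         # regular projection: gets quantized

    assert not extended_linear_class_predicate(lm_head)
    assert extended_linear_class_predicate(q_proj)

The "non-32 dims" comment in the diff refers to legacy conversions that skipped quantizing lm_head because its dimensions were not compatible with the quantization group size, which is why those checkpoints ship without lm_head.scales.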