From 5fc8668a539e47318e2e797ffcc313943cca262e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 25 Jan 2024 02:44:05 +1100
Subject: [PATCH] fix(mlx-lm): handle legacy quant models (#369)
---
llms/mlx_lm/utils.py | 24 +++++++++++++++++++-----
1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4a53ee9d..ab5b99af 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -226,11 +226,25 @@ def load_model(model_path: Path) -> nn.Module:
     model = model_class(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
-            model,
-            **quantization,
-            linear_class_predicate=linear_class_predicate,
-        )
+        # for legacy models that don't have lm_head quant due to non-32 dims
+        if "lm_head.scales" not in weights.keys():
+            vocab_size = config["vocab_size"]
+            extended_linear_class_predicate = (
+                lambda layer: linear_class_predicate(layer)
+                and layer.weight.shape[0] != vocab_size
+            )
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=extended_linear_class_predicate,
+            )
+        # for models that have lm_head quant
+        else:
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=linear_class_predicate,
+            )
     model.load_weights(list(weights.items()))
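
Note (editor, not part of the patch): the two quantize_module branches above differ
only in the predicate they pass, so the decision can be read as a small standalone
helper. The sketch below is plain Python; the helper name pick_predicate and its
argument names are illustrative, while weights, config, and linear_class_predicate
mirror the variables already used in load_model.

    def pick_predicate(weights, config, base_predicate):
        # Legacy checkpoints skipped lm_head quantization (the patch comment cites
        # non-32 dims), so they carry no "lm_head.scales" entry. For those, exclude
        # any Linear whose output dimension equals the vocab size so lm_head stays
        # in full precision and load_weights still matches the checkpoint.
        if "lm_head.scales" not in weights:
            vocab_size = config["vocab_size"]
            return lambda layer: (
                base_predicate(layer) and layer.weight.shape[0] != vocab_size
            )
        # Newer checkpoints ship a quantized lm_head; quantize every eligible Linear.
        return base_predicate

With such a helper the hunk could collapse to a single call, e.g.
nn.QuantizedLinear.quantize_module(model, **quantization,
linear_class_predicate=pick_predicate(weights, config, linear_class_predicate)),
though the patch keeps the explicit if/else, which reads clearly enough for two cases.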