fix(mlx-lm): handle legacy quant models (#369)

Anchen authored 2024-01-25 02:44:05 +11:00, committed by GitHub
parent ab91ac1075
commit 5fc8668a53
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)


@@ -226,11 +226,25 @@ def load_model(model_path: Path) -> nn.Module:
     model = model_class(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
-            model,
-            **quantization,
-            linear_class_predicate=linear_class_predicate,
-        )
+        # for legacy models that don't have lm_head quant due to non-32 dims
+        if "lm_head.scales" not in weights.keys():
+            vocab_size = config["vocab_size"]
+            extended_linear_class_predicate = (
+                lambda layer: linear_class_predicate(layer)
+                and layer.weight.shape[0] != vocab_size
+            )
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=extended_linear_class_predicate,
+            )
+        # for models that have lm_head quant
+        else:
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=linear_class_predicate,
+            )
     model.load_weights(list(weights.items()))
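
For illustration, here is a minimal runnable sketch of the predicate trick in the diff, using the same nn.QuantizedLinear.quantize_module classmethod the changed code calls (available in the mlx version contemporary with this commit; later releases replaced it with nn.quantize). ToyModel, its layer sizes, and the group_size choice are hypothetical, picked so the lm_head output dimension (vocab_size) is not a multiple of 32, mirroring the legacy checkpoints this commit handles: the extended predicate leaves that layer unquantized, so the model still matches weights saved without lm_head.scales.

# Hypothetical toy model; only the predicate logic comes from the diff above.
import mlx.nn as nn

vocab_size = 11  # deliberately not a multiple of 32, like legacy vocab dims

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)             # gets quantized
        self.lm_head = nn.Linear(64, vocab_size)  # weight.shape[0] == vocab_size

linear_class_predicate = lambda layer: isinstance(layer, nn.Linear)
# Same shape-based exclusion as the diff: skip any Linear whose output
# dimension equals vocab_size, i.e. the legacy lm_head saved in full precision.
extended_linear_class_predicate = (
    lambda layer: linear_class_predicate(layer)
    and layer.weight.shape[0] != vocab_size
)

model = ToyModel()
nn.QuantizedLinear.quantize_module(
    model,
    group_size=32,
    bits=4,
    linear_class_predicate=extended_linear_class_predicate,
)
print(type(model.proj).__name__)     # QuantizedLinear
print(type(model.lm_head).__name__)  # Linear (left unquantized)

Note the predicate receives only the layer itself, not its name or path, so the diff identifies lm_head by its output dimension rather than by key, which is why it compares layer.weight.shape[0] against vocab_size.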