mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix(mlx-lm): handle legacy quant models (#369)
This commit is contained in:
parent
ab91ac1075
commit
5fc8668a53
@ -226,11 +226,25 @@ def load_model(model_path: Path) -> nn.Module:
|
||||
model = model_class(model_args)
|
||||
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model,
|
||||
**quantization,
|
||||
linear_class_predicate=linear_class_predicate,
|
||||
)
|
||||
# for legacy models whose lm_head was left unquantized because its dims are not a multiple of 32
|
||||
if "lm_head.scales" not in weights.keys():
|
||||
vocab_size = config["vocab_size"]
|
||||
extended_linear_class_predicate = (
|
||||
lambda layer: linear_class_predicate(layer)
|
||||
and layer.weight.shape[0] != vocab_size
|
||||
)
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model,
|
||||
**quantization,
|
||||
linear_class_predicate=extended_linear_class_predicate,
|
||||
)
|
||||
# for models whose lm_head is quantized
|
||||
else:
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model,
|
||||
**quantization,
|
||||
linear_class_predicate=linear_class_predicate,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user