mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix(mlx-lm): handle legacy quant models (#369)
parent ab91ac1075
commit 5fc8668a53
@@ -226,11 +226,25 @@ def load_model(model_path: Path) -> nn.Module:
     model = model_class(model_args)
 
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
-            model,
-            **quantization,
-            linear_class_predicate=linear_class_predicate,
-        )
+        # for legacy models that don't have lm_head quant due to non-32 dims
+        if "lm_head.scales" not in weights.keys():
+            vocab_size = config["vocab_size"]
+            extended_linear_class_predicate = (
+                lambda layer: linear_class_predicate(layer)
+                and layer.weight.shape[0] != vocab_size
+            )
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=extended_linear_class_predicate,
+            )
+        # for models that have lm_head quant
+        else:
+            nn.QuantizedLinear.quantize_module(
+                model,
+                **quantization,
+                linear_class_predicate=linear_class_predicate,
+            )
 
     model.load_weights(list(weights.items()))