mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
one more quant fix (#708)
This commit is contained in:
parent
8d5cf5b0c8
commit
ecbc6ff1e3
@ -18,12 +18,10 @@ def quantize(weights, config, args):
|
|||||||
model.load_weights(list(weights.items()))
|
model.load_weights(list(weights.items()))
|
||||||
|
|
||||||
# Quantize the model:
|
# Quantize the model:
|
||||||
nn.QuantizedLinear.quantize_module(
|
nn.quantize(
|
||||||
model,
|
model,
|
||||||
args.q_group_size,
|
args.q_group_size,
|
||||||
args.q_bits,
|
args.q_bits,
|
||||||
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
|
|
||||||
and m.weight.shape[0] != 8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the config:
|
# Update the config:
|
||||||
|
@ -147,11 +147,14 @@ def load(path_or_hf_repo: str):
|
|||||||
model_args = models.ModelArgs.from_dict(config)
|
model_args = models.ModelArgs.from_dict(config)
|
||||||
model = models.Model(model_args)
|
model = models.Model(model_args)
|
||||||
if quantization is not None:
|
if quantization is not None:
|
||||||
nn.QuantizedLinear.quantize_module(
|
class_predicate = (
|
||||||
|
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||||
|
and f"{p}.scales" in weights
|
||||||
|
)
|
||||||
|
nn.quantize(
|
||||||
model,
|
model,
|
||||||
**quantization,
|
**quantization,
|
||||||
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
|
class_predicate=class_predicate,
|
||||||
and m.weight.shape[0] != 8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(list(weights.items()))
|
model.load_weights(list(weights.items()))
|
||||||
|
Loading…
Reference in New Issue
Block a user