mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
quantize linear (#250)
This commit is contained in:
parent
737b4c81a3
commit
485fb9ac0f
@ -28,8 +28,6 @@ def quantize(weights, config, args):
|
|||||||
model,
|
model,
|
||||||
args.q_group_size,
|
args.q_group_size,
|
||||||
args.q_bits,
|
args.q_bits,
|
||||||
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
|
|
||||||
and m.weight.shape[0] != config["vocab_size"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the config:
|
# Update the config:
|
||||||
|
@ -339,9 +339,6 @@ def load_model(folder: str):
|
|||||||
model_args = ModelArgs(**config)
|
model_args = ModelArgs(**config)
|
||||||
model = Model(model_args)
|
model = Model(model_args)
|
||||||
if quantization is not None:
|
if quantization is not None:
|
||||||
quantization["linear_class_predicate"] = lambda m: isinstance(
|
|
||||||
m, nn.Linear
|
|
||||||
) and (m.weight.shape[0] != model_args.vocab_size)
|
|
||||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||||
|
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
weights = mx.load(str(model_path / "weights.npz"))
|
||||||
|
Loading…
Reference in New Issue
Block a user