quantize linear (#250)

This commit is contained in:
Awni Hannun 2024-01-07 18:48:59 -08:00 committed by GitHub
parent 737b4c81a3
commit 485fb9ac0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 0 additions and 5 deletions

View File

@@ -28,8 +28,6 @@ def quantize(weights, config, args):
         model,
         args.q_group_size,
         args.q_bits,
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != config["vocab_size"],
     )
     # Update the config:
View File

@ -339,9 +339,6 @@ def load_model(folder: str):
model_args = ModelArgs(**config) model_args = ModelArgs(**config)
model = Model(model_args) model = Model(model_args)
if quantization is not None: if quantization is not None:
quantization["linear_class_predicate"] = lambda m: isinstance(
m, nn.Linear
) and (m.weight.shape[0] != model_args.vocab_size)
nn.QuantizedLinear.quantize_module(model, **quantization) nn.QuantizedLinear.quantize_module(model, **quantization)
weights = mx.load(str(model_path / "weights.npz")) weights = mx.load(str(model_path / "weights.npz"))