one more quant fix (#708)

Awni Hannun 2024-04-22 18:12:52 -07:00 committed by GitHub
parent 8d5cf5b0c8
commit ecbc6ff1e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 6 deletions


@@ -18,12 +18,10 @@ def quantize(weights, config, args):
     model.load_weights(list(weights.items()))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
+    nn.quantize(
         model,
         args.q_group_size,
         args.q_bits,
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != 8,
     )
 
     # Update the config:
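In short: the converter now calls the newer `nn.quantize` API instead of `nn.QuantizedLinear.quantize_module`, dropping the old `linear_class_predicate` special case. A minimal sketch of the new call, assuming a recent `mlx` install; `TinyModel` is a made-up module for illustration only:

```python
import mlx.nn as nn

class TinyModel(nn.Module):
    # Hypothetical two-layer model, just to exercise nn.quantize.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

    def __call__(self, x):
        return self.fc2(nn.relu(self.fc1(x)))

model = TinyModel()
# Replaces supported submodules (the nn.Linear layers here) with their
# quantized equivalents in place; group_size and bits mirror
# args.q_group_size / args.q_bits in the diff above.
nn.quantize(model, group_size=64, bits=4)
print(model)  # fc1 and fc2 now show up as QuantizedLinear
```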


@@ -147,11 +147,14 @@ def load(path_or_hf_repo: str):
     model_args = models.ModelArgs.from_dict(config)
     model = models.Model(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
             model,
             **quantization,
-            linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-            and m.weight.shape[0] != 8,
+            class_predicate=class_predicate,
         )
 
     model.load_weights(list(weights.items()))
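On the loading side the predicate changes shape as well: `nn.quantize`'s `class_predicate` receives the module path `p` along with the module `m`, so the loader quantizes exactly the `nn.Linear` and `nn.Embedding` modules whose `scales` arrays are present in the checkpoint. A hedged sketch of that pattern as a standalone helper (the function name, checkpoint path, and quantization dict are placeholders, not part of this commit):

```python
import mlx.core as mx
import mlx.nn as nn

def quantize_from_checkpoint(model, weights, quantization):
    # Swap in quantized layers only where the checkpoint carries
    # quantized parameters, i.e. where a "{path}.scales" array exists.
    class_predicate = (
        lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
        and f"{p}.scales" in weights
    )
    nn.quantize(model, **quantization, class_predicate=class_predicate)
    model.load_weights(list(weights.items()))
    return model

# Hypothetical usage with a safetensors checkpoint and the saved
# config's "quantization" section:
# weights = dict(mx.load("model.safetensors"))
# model = quantize_from_checkpoint(model, weights, {"group_size": 64, "bits": 4})
```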