diff --git a/lora/convert.py b/lora/convert.py
index 26928c96..d48c6b87 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -18,12 +18,10 @@ def quantize(weights, config, args):
     model.load_weights(list(weights.items()))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
+    nn.quantize(
         model,
         args.q_group_size,
         args.q_bits,
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != 8,
     )
 
     # Update the config:
diff --git a/lora/utils.py b/lora/utils.py
index 0e7c1fb9..e768e9a8 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -147,11 +147,14 @@ def load(path_or_hf_repo: str):
     model_args = models.ModelArgs.from_dict(config)
     model = models.Model(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
             model,
             **quantization,
-            linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-            and m.weight.shape[0] != 8,
+            class_predicate=class_predicate,
         )
 
     model.load_weights(list(weights.items()))
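
For context, a minimal standalone sketch of the `nn.quantize` call pattern this diff switches to: quantize only modules whose `"<path>.scales"` entry exists in the saved weights, as in `load()` above. The toy `nn.Sequential` model, the `saved_weights` dict, and the `group_size`/`bits` values are illustrative placeholders, not taken from this repo.

import mlx.core as mx
import mlx.nn as nn

# Toy model with one embedding and one linear layer (paths "layers.0", "layers.1").
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64))

# Pretend only the Linear layer was quantized when the checkpoint was saved,
# so only it has a "<path>.scales" entry. The value is a placeholder; the
# predicate below only checks key membership.
saved_weights = {"layers.1.scales": mx.zeros(1)}

# Same (path, module) predicate shape as in load(): quantize Linear/Embedding
# modules only when their quantized scales are present in the weights.
class_predicate = (
    lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
    and f"{p}.scales" in saved_weights
)

nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)
print(model)  # expected: layers.1 becomes QuantizedLinear, layers.0 stays an Embedding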