diff --git a/lora/convert.py b/lora/convert.py
index a9702c03..38342080 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -28,8 +28,6 @@ def quantize(weights, config, args):
         model,
         args.q_group_size,
         args.q_bits,
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != config["vocab_size"],
     )
 
     # Update the config:
diff --git a/lora/lora.py b/lora/lora.py
index 3eed73fb..1c8fe17f 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -339,9 +339,6 @@ def load_model(folder: str):
     model_args = ModelArgs(**config)
     model = Model(model_args)
     if quantization is not None:
-        quantization["linear_class_predicate"] = lambda m: isinstance(
-            m, nn.Linear
-        ) and (m.weight.shape[0] != model_args.vocab_size)
         nn.QuantizedLinear.quantize_module(model, **quantization)
 
     weights = mx.load(str(model_path / "weights.npz"))
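
For context, a minimal sketch of how the quantization call in `lora/convert.py` reads once the explicit `linear_class_predicate` is dropped. The wrapper name and its default group size / bit width are illustrative assumptions; only the `nn.QuantizedLinear.quantize_module` call itself comes from the diff.

```python
# Sketch only: mirrors the call left in lora/convert.py after this change.
# `quantize_in_place` and its defaults are illustrative, not part of the repo.
import mlx.nn as nn


def quantize_in_place(model: nn.Module, q_group_size: int = 64, q_bits: int = 4) -> nn.Module:
    # Without the explicit vocab-size predicate, the library's default
    # predicate decides which nn.Linear layers are swapped for
    # nn.QuantizedLinear (typically all of them).
    nn.QuantizedLinear.quantize_module(model, q_group_size, q_bits)
    return model
```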