Add dropout parameter to lora configuration (#599)

* Add dropout parameter to lora configuration

A dropout parameter has been added to the lora configuration settings in lora_config.yaml. The LoRALinear class in utils.py has been updated to take this new parameter. Additionally, an `AttributeError: 'types.SimpleNamespace' object has no attribute 'prompt'` related to `args.prompt` has been fixed in lora.py.

* Update lora_config.yaml

Set dropout to 0.0 in the sample config file

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Ivan Fioravanti 2024-03-20 16:44:40 +01:00 committed by GitHub
parent 949f63f309
commit d2a99172a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 6 deletions

View File

@@ -59,3 +59,4 @@ lora_parameters:
rank: 8
alpha: 16.0
scale: 10.0
dropout: 0.0

View File

@@ -235,11 +235,6 @@ def run(args, training_callback: TrainingCallback = None):
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if args.prompt is not None:
raise NotImplementedError(
"Please use mlx_lm.generate with trained adapter for generation."
)
if __name__ == "__main__":
parser = build_parser()

View File

@@ -32,7 +32,11 @@ def linear_to_lora_layers(
)
to_lora = lambda lin: LoRALinear.from_linear(
lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
lin,
r=config["rank"],
alpha=config["alpha"],
scale=config["scale"],
dropout=config["dropout"],
)
keys = config.get("keys", None)