From d2a99172a692b25f8ca725633f4712ce49cca280 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 20 Mar 2024 16:44:40 +0100 Subject: [PATCH] Add dropout parameter to lora configuration (#599) * Add dropout parameter to lora configuration A dropout parameter has been added to the lora configuration settings in lora_config.yaml. The LoRALinear class in utils.py has been updated to take this new parameter. Additionally, the code related to `args.prompt`, which raised AttributeError: 'types.SimpleNamespace' object has no attribute 'prompt', has been removed from lora.py. * Update lora_config.yaml Set dropout to 0.0 in the sample config file * format --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/examples/lora_config.yaml | 1 + llms/mlx_lm/lora.py | 5 ----- llms/mlx_lm/tuner/utils.py | 6 +++++- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index b616aaf4..1585d69e 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -59,3 +59,4 @@ lora_parameters: rank: 8 alpha: 16.0 scale: 10.0 + dropout: 0.0 diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index b89d8f0e..adc426e4 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -235,11 +235,6 @@ def run(args, training_callback: TrainingCallback = None): print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") - if args.prompt is not None: - raise NotImplementedError( - "Please use mlx_lm.generate with trained adapter for generation." 
- ) - if __name__ == "__main__": parser = build_parser() diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index f5ce4163..b465146c 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -32,7 +32,11 @@ def linear_to_lora_layers( ) to_lora = lambda lin: LoRALinear.from_linear( - lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"] + lin, + r=config["rank"], + alpha=config["alpha"], + scale=config["scale"], + dropout=config["dropout"], ) keys = config.get("keys", None)