From e5dfef5d9aa60c80f3e12f158da49147a366f2d2 Mon Sep 17 00:00:00 2001
From: Madroid Ma
Date: Tue, 27 Feb 2024 11:35:04 +0800
Subject: [PATCH] LoRA: Extract the run function for easy use in scripts file (#482)

* LoRA: Extract the run_lora function for easy use in scripts

* LoRA: run_lora function adds a TrainingCallback pass.

* LoRA: change run_lora to run
---
 llms/mlx_lm/lora.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index a8a2912e..e433bbd6 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -7,7 +7,7 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .tuner.trainer import TrainingArgs, evaluate, train
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import generate, load
 
@@ -160,10 +160,7 @@ def load_dataset(args):
     return train, valid, test
 
 
-if __name__ == "__main__":
-    parser = build_parser()
-    args = parser.parse_args()
-
+def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
     print("Loading pretrained model")
@@ -209,6 +206,7 @@ if __name__ == "__main__":
         optimizer=opt,
         train_dataset=train_set,
         val_dataset=valid_set,
+        training_callback=training_callback,
     )
 
     # Load the LoRA adapter weights which we assume should exist by this point
@@ -246,3 +244,10 @@ if __name__ == "__main__":
             prompt=args.prompt,
             verbose=True,
         )
+
+
+if __name__ == "__main__":
+    parser = build_parser()
+    args = parser.parse_args()
+
+    run(args)
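
Usage note (not part of the patch): after this change a script can import run() and drive LoRA fine-tuning directly, passing a TrainingCallback, instead of invoking the module from the command line. The sketch below is a minimal, hedged example; the import path mlx_lm.lora, the --model/--train flags, and the on_train_loss_report / on_val_loss_report hook names are assumptions about the surrounding package and trainer API rather than anything this diff itself defines.

    # Minimal usage sketch (assumptions noted above): call the new run() entry
    # point from a script instead of the CLI, with a custom callback.
    from mlx_lm.lora import build_parser, run          # import path assumed
    from mlx_lm.tuner.trainer import TrainingCallback


    class LossLogger(TrainingCallback):
        """Hypothetical callback that collects reported losses in memory."""

        def __init__(self):
            self.train_reports = []
            self.val_reports = []

        # Hook names are assumptions about the TrainingCallback interface.
        def on_train_loss_report(self, info):
            self.train_reports.append(info)

        def on_val_loss_report(self, info):
            self.val_reports.append(info)


    # Reuse the existing argument parser so defaults stay consistent with the
    # CLI; the flags shown here are assumed options of lora.py.
    args = build_parser().parse_args(
        ["--model", "mistralai/Mistral-7B-v0.1", "--train"]
    )

    callback = LossLogger()
    run(args, training_callback=callback)
    print(f"collected {len(callback.train_reports)} training loss reports")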