diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index a8a2912e..e433bbd6 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -7,7 +7,7 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .tuner.trainer import TrainingArgs, evaluate, train
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import generate, load
 
@@ -160,10 +160,7 @@ def load_dataset(args):
     return train, valid, test
 
 
-if __name__ == "__main__":
-    parser = build_parser()
-    args = parser.parse_args()
-
+def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
     print("Loading pretrained model")
@@ -209,6 +206,7 @@ if __name__ == "__main__":
             optimizer=opt,
             train_dataset=train_set,
             val_dataset=valid_set,
+            training_callback=training_callback,
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
@@ -246,3 +244,10 @@ if __name__ == "__main__":
             prompt=args.prompt,
             verbose=True,
         )
+
+
+if __name__ == "__main__":
+    parser = build_parser()
+    args = parser.parse_args()
+
+    run(args)
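
For reference, here is a minimal sketch of how the new `run()` entry point could be driven programmatically with a callback. The hook names (`on_train_loss_report`, `on_val_loss_report`) and the `mlx_lm` import path are assumptions about the `TrainingCallback` interface defined in `tuner/trainer.py`; they are not part of this diff.

```python
# Hypothetical usage sketch (not part of this diff): invoke LoRA fine-tuning
# through the new run() entry point and observe loss reports via a callback.
# The hook method names below are assumed, not confirmed by this diff.
from mlx_lm.lora import build_parser, run
from mlx_lm.tuner.trainer import TrainingCallback


class LossLogger(TrainingCallback):
    # Assumed hook: called by the trainer when it reports a training loss.
    def on_train_loss_report(self, info):
        print("train:", info)

    # Assumed hook: called by the trainer when it reports a validation loss.
    def on_val_loss_report(self, info):
        print("val:", info)


if __name__ == "__main__":
    # Same CLI flags as `python -m mlx_lm.lora`; run() now accepts a callback.
    args = build_parser().parse_args()
    run(args, training_callback=LossLogger())
```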