LoRA: Extract the run function for easy use in scripts (#482)

* LoRA: Extract the run_lora function for easy use in scripts

* LoRA: add a TrainingCallback parameter to run_lora.

* LoRA: rename run_lora to run.
Madroid Ma 2024-02-27 11:35:04 +08:00 committed by GitHub
parent ccb278bcbd
commit e5dfef5d9a


@@ -7,7 +7,7 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .tuner.trainer import TrainingArgs, evaluate, train
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import generate, load
 
@@ -160,10 +160,7 @@ def load_dataset(args):
     return train, valid, test
 
 
-if __name__ == "__main__":
-    parser = build_parser()
-    args = parser.parse_args()
-
+def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
     print("Loading pretrained model")
@@ -209,6 +206,7 @@ if __name__ == "__main__":
         optimizer=opt,
         train_dataset=train_set,
         val_dataset=valid_set,
+        training_callback=training_callback,
     )
 
     # Load the LoRA adapter weights which we assume should exist by this point
@@ -246,3 +244,10 @@ if __name__ == "__main__":
         prompt=args.prompt,
         verbose=True,
     )
+
+
+if __name__ == "__main__":
+    parser = build_parser()
+    args = parser.parse_args()
+
+    run(args)
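
With run exposed at module level, an external script can drive LoRA fine-tuning programmatically and observe training through a callback. Below is a minimal sketch of such a caller, assuming the package is importable as mlx_lm and that TrainingCallback in .tuner.trainer exposes on_train_loss_report and on_val_loss_report hooks; the LossLogger class and the example CLI arguments are hypothetical.

# Hypothetical caller script. Assumes mlx_lm is installed and that
# TrainingCallback defines on_train_loss_report / on_val_loss_report hooks.
from mlx_lm.lora import build_parser, run
from mlx_lm.tuner.trainer import TrainingCallback


class LossLogger(TrainingCallback):
    """Collect reported losses instead of only printing them."""

    def __init__(self):
        self.train_losses = []
        self.val_losses = []

    def on_train_loss_report(self, train_info):
        # train_info is assumed to carry the iteration's reported train loss
        self.train_losses.append(train_info)

    def on_val_loss_report(self, val_info):
        self.val_losses.append(val_info)


if __name__ == "__main__":
    # Reuse the module's own parser instead of duplicating its flags;
    # the model path here is a placeholder.
    args = build_parser().parse_args(["--model", "my-model-path", "--train"])
    callback = LossLogger()
    run(args, training_callback=callback)
    print(f"Recorded {len(callback.train_losses)} train loss reports")

Because run reuses the parsed args object, a caller can also construct args with parse_args([]) and override fields directly, keeping script behavior in sync with the CLI defaults.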