mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
LoRA: Extract the run function for easy use in scripts file (#482)
* LoRA: Extract the run_lora function for easy use in scripts * LoRA: run_lora function adds a TrainingCallback pass. * LoRA: change run_lora to run
This commit is contained in:
parent
ccb278bcbd
commit
e5dfef5d9a
@ -7,7 +7,7 @@ import mlx.optimizers as optim
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
from .tuner.trainer import TrainingArgs, evaluate, train
|
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||||
from .tuner.utils import linear_to_lora_layers
|
from .tuner.utils import linear_to_lora_layers
|
||||||
from .utils import generate, load
|
from .utils import generate, load
|
||||||
|
|
||||||
@ -160,10 +160,7 @@ def load_dataset(args):
|
|||||||
return train, valid, test
|
return train, valid, test
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def run(args, training_callback: TrainingCallback = None):
|
||||||
parser = build_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
print("Loading pretrained model")
|
print("Loading pretrained model")
|
||||||
@ -209,6 +206,7 @@ if __name__ == "__main__":
|
|||||||
optimizer=opt,
|
optimizer=opt,
|
||||||
train_dataset=train_set,
|
train_dataset=train_set,
|
||||||
val_dataset=valid_set,
|
val_dataset=valid_set,
|
||||||
|
training_callback=training_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load the LoRA adapter weights which we assume should exist by this point
|
# Load the LoRA adapter weights which we assume should exist by this point
|
||||||
@ -246,3 +244,10 @@ if __name__ == "__main__":
|
|||||||
prompt=args.prompt,
|
prompt=args.prompt,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = build_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
run(args)
|
||||||
|
Loading…
Reference in New Issue
Block a user