mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
@@ -1,7 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import types
|
||||
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
|
||||
from .tuner.datasets import load_dataset
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.utils import build_schedule, linear_to_lora_layers
|
||||
from .utils import load
|
||||
from .utils import load, save_config
|
||||
|
||||
yaml_loader = yaml.SafeLoader
|
||||
yaml_loader.add_implicit_resolver(
|
||||
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
|
||||
"steps_per_report": 10,
|
||||
"steps_per_eval": 200,
|
||||
"resume_adapter_file": None,
|
||||
"adapter_file": "adapters.npz",
|
||||
"adapter_path": "adapters",
|
||||
"save_every": 100,
|
||||
"test": False,
|
||||
"test_batches": 500,
|
||||
@@ -102,12 +101,12 @@ def build_parser():
|
||||
parser.add_argument(
|
||||
"--resume-adapter-file",
|
||||
type=str,
|
||||
help="Load path to resume training with the given adapter weights.",
|
||||
help="Load path to resume training with the given adapters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-file",
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Save/load path for the trained adapter weights.",
|
||||
help="Save/load path for the adapters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-every",
|
||||
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file, strict=False)
|
||||
|
||||
adapter_path = Path(args.adapter_path)
|
||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
adapter_file = adapter_path / "adapters.safetensors"
|
||||
|
||||
if args.train:
|
||||
print("Training")
|
||||
# init training args
|
||||
@@ -194,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=args.adapter_file,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
)
|
||||
@@ -219,12 +223,12 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
)
|
||||
|
||||
# Load the LoRA adapter weights which we assume should exist by this point
|
||||
if not Path(args.adapter_file).is_file():
|
||||
if not adapter_file.is_file():
|
||||
raise ValueError(
|
||||
f"Adapter file {args.adapter_file} missing. "
|
||||
"Use --train to learn and save the adapters.npz."
|
||||
f"Adapter file {adapter_file} missing. "
|
||||
"Use --train to learn and save the adapters"
|
||||
)
|
||||
model.load_weights(args.adapter_file, strict=False)
|
||||
model.load_weights(str(adapter_file), strict=False)
|
||||
|
||||
if args.test:
|
||||
print("Testing")
|
||||
|
Reference in New Issue
Block a user