Save lora config (#636)

* lora config

* comments

* version bump
Author: Awni Hannun (committed by GitHub)
Date:   2024-04-02 13:52:53 -07:00
Parent: d661440dbb
Commit: 2bd64b78cf
10 changed files with 73 additions and 90 deletions


@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
-import json
 import math
 import re
 import types
@@ -16,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import build_schedule, linear_to_lora_layers
-from .utils import load
+from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
 yaml_loader.add_implicit_resolver(
@@ -48,7 +47,7 @@ CONFIG_DEFAULTS = {
     "steps_per_report": 10,
     "steps_per_eval": 200,
     "resume_adapter_file": None,
-    "adapter_file": "adapters.npz",
+    "adapter_path": "adapters",
     "save_every": 100,
     "test": False,
     "test_batches": 500,
@@ -102,12 +101,12 @@ def build_parser():
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        help="Load path to resume training with the given adapter weights.",
+        help="Load path to resume training with the given adapters.",
     )
     parser.add_argument(
-        "--adapter-file",
+        "--adapter-path",
         type=str,
-        help="Save/load path for the trained adapter weights.",
+        help="Save/load path for the adapters.",
     )
     parser.add_argument(
         "--save-every",
@@ -184,6 +183,11 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), adapter_path / "adapter_config.json")
+    adapter_file = adapter_path / "adapters.safetensors"
+
     if args.train:
         print("Training")
         # init training args
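
Note (illustrative): save_config is imported from .utils, one of the other changed files not shown in this excerpt. A minimal sketch of what such a helper might do, assuming it simply serializes the argument dict to JSON:

import json


def save_config(config: dict, config_path) -> None:
    # Hedged sketch only; the real helper in .utils may differ, e.g. it could
    # drop keys that are not JSON-serializable before writing.
    with open(config_path, "w") as fid:
        json.dump(dict(sorted(config.items())), fid, indent=4)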
@@ -194,7 +198,7 @@
             steps_per_report=args.steps_per_report,
             steps_per_eval=args.steps_per_eval,
             steps_per_save=args.save_every,
-            adapter_file=args.adapter_file,
+            adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
         )
@@ -219,12 +223,12 @@
         )
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    if not Path(args.adapter_file).is_file():
+    if not adapter_file.is_file():
         raise ValueError(
-            f"Adapter file {args.adapter_file} missing. "
-            "Use --train to learn and save the adapters.npz."
+            f"Adapter file {adapter_file} missing. "
+            "Use --train to learn and save the adapters"
         )
-    model.load_weights(args.adapter_file, strict=False)
+    model.load_weights(str(adapter_file), strict=False)
 
     if args.test:
         print("Testing")