YAML configuration for mlx_lm.lora (#503)

* Convert mlx_lm.lora to use YAML configuration

* pre-commit run fixes

* Fix loading of config file

* Remove invalid YAML from doc

* Update command-line options and YAML parameter overriding, per feedback in #503

* Minor wording change

* Made the config file a positional argument

* Moved config to a (-c/--config) flag

* Removed CLI option defaults, since CLI options take precedence and their defaults now live in CONFIG_DEFAULTS (see the example configuration below)

* pre-commit format updates

* Fix handling of CLI option defaults

* Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Commit 8c2cf665ed (parent 8b05bb6d18)
Authored by Chime Ogbuji on 2024-03-08 10:57:52 -05:00; committed by GitHub
3 changed files with 129 additions and 28 deletions
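
Before the diff, an illustrative sketch of the kind of file the new -c/--config flag accepts. The keys mirror CONFIG_DEFAULTS in the first hunk below; the values and the file name lora_config.yaml are hypothetical, not recommendations:

    # lora_config.yaml (hypothetical) -- every key is optional; anything omitted
    # falls back to CONFIG_DEFAULTS, and any explicit CLI flag takes precedence.
    model: "mlx_model"
    train: true
    data: "data/"
    lora_layers: 16
    batch_size: 4
    iters: 1000
    learning_rate: 1e-5
    adapter_file: "adapters.npz"
    max_seq_length: 2048

Note the dotless scientific notation in learning_rate: 1e-5; the custom float resolver in the first hunk exists precisely so that such values load as floats rather than strings.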


@@ -1,22 +1,61 @@
 import argparse
 import json
 import math
+import re
+import types
 from pathlib import Path

 import mlx.optimizers as optim
 import numpy as np
+import yaml
 from mlx.utils import tree_flatten

 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import load

+yaml_loader = yaml.SafeLoader
+yaml_loader.add_implicit_resolver(
+    "tag:yaml.org,2002:float",
+    re.compile(
+        """^(?:
+     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+    |[-+]?\\.(?:inf|Inf|INF)
+    |\\.(?:nan|NaN|NAN))$""",
+        re.X,
+    ),
+    list("-+0123456789."),
+)
+
+CONFIG_DEFAULTS = {
+    "model": "mlx_model",
+    "train": False,
+    "data": "data/",
+    "seed": 0,
+    "lora_layers": 16,
+    "batch_size": 4,
+    "iters": 1000,
+    "val_batches": 25,
+    "learning_rate": 1e-5,
+    "steps_per_report": 10,
+    "steps_per_eval": 200,
+    "resume_adapter_file": None,
+    "adapter_file": "adapters.npz",
+    "save_every": 100,
+    "test": False,
+    "test_batches": 500,
+    "max_seq_length": 2048,
+}
+

 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
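
Why the implicit resolver above is needed: PyYAML's SafeLoader follows YAML 1.1 resolution, under which a dotless scientific scalar such as 1e-5 loads as a string, so a learning_rate written that way would silently arrive as str. A minimal sketch of the quirk and the fix; unlike the commit, which mutates SafeLoader in place, this uses a subclass and a condensed regex, for illustration only:

    import re

    import yaml

    # Stock SafeLoader: "1e-5" has no '.', so YAML 1.1 resolution leaves it a string.
    assert yaml.safe_load("learning_rate: 1e-5") == {"learning_rate": "1e-5"}

    class FloatLoader(yaml.SafeLoader):
        pass

    # Teach the loader to resolve dotless scientific notation as a float.
    FloatLoader.add_implicit_resolver(
        "tag:yaml.org,2002:float",
        re.compile(r"^[-+]?\d[\d_]*[eE][-+]?\d+$"),
        list("-+0123456789"),
    )
    assert yaml.load("learning_rate: 1e-5", FloatLoader) == {"learning_rate": 1e-05}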
@@ -24,18 +63,14 @@ def build_parser():
         "--max-tokens",
         "-m",
         type=int,
-        default=100,
         help="The maximum number of tokens to generate",
     )
-    parser.add_argument(
-        "--temp", type=float, default=0.8, help="The sampling temperature"
-    )
+    parser.add_argument("--temp", type=float, help="The sampling temperature")
     parser.add_argument(
         "--prompt",
         "-p",
         type=str,
         help="The prompt for generation",
-        default=None,
     )

     # Training args
@@ -47,56 +82,44 @@ def build_parser():
     parser.add_argument(
         "--data",
         type=str,
-        default="data/",
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
         "--lora-layers",
         type=int,
-        default=16,
         help="Number of layers to fine-tune",
     )
-    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
-    parser.add_argument(
-        "--iters", type=int, default=1000, help="Iterations to train for."
-    )
+    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
+    parser.add_argument("--iters", type=int, help="Iterations to train for.")
     parser.add_argument(
         "--val-batches",
         type=int,
-        default=25,
         help="Number of validation batches, -1 uses the entire validation set.",
     )
-    parser.add_argument(
-        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
-    )
+    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
     parser.add_argument(
         "--steps-per-report",
         type=int,
-        default=10,
         help="Number of training steps between loss reporting.",
     )
     parser.add_argument(
         "--steps-per-eval",
         type=int,
-        default=200,
         help="Number of training steps between validations.",
     )
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        default=None,
         help="Load path to resume training with the given adapter weights.",
     )
     parser.add_argument(
         "--adapter-file",
         type=str,
-        default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
     parser.add_argument(
         "--save-every",
         type=int,
-        default=100,
         help="Save the model every N iterations.",
     )
     parser.add_argument(
@@ -107,16 +130,20 @@ def build_parser():
     parser.add_argument(
         "--test-batches",
         type=int,
-        default=500,
         help="Number of test set batches, -1 uses the entire test set.",
     )
     parser.add_argument(
         "--max-seq-length",
         type=int,
-        default=2048,
         help="Maximum sequence length.",
     )
-    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "-c",
+        "--config",
+        default=None,
+        help="A YAML configuration file with the training options",
+    )
+    parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
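
With the new flag in place, a typical invocation might look like this (file name hypothetical); per the merge logic in the final hunk below, --iters 500 would override any iters value in the file:

    python -m mlx_lm.lora --config lora_config.yaml --train --iters 500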
@@ -242,5 +269,19 @@ def run(args, training_callback: TrainingCallback = None):
 if __name__ == "__main__":
     parser = build_parser()
     args = parser.parse_args()
+    config = args.config
+    args = vars(args)
+    if config:
+        print("Loading configuration file", config)
+        with open(config, "r") as file:
+            config = yaml.load(file, yaml_loader)
+        # Prefer parameters from command-line arguments
+        for k, v in config.items():
+            if not args.get(k, None):
+                args[k] = v
-    run(args)
+    # Update defaults for unspecified parameters
+    for k, v in CONFIG_DEFAULTS.items():
+        if not args.get(k, None):
+            args[k] = v
+    run(types.SimpleNamespace(**args))
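
To make the resulting precedence (CLI > YAML file > CONFIG_DEFAULTS) concrete, a self-contained sketch of the same merge with invented values:

    import types

    CONFIG_DEFAULTS = {"iters": 1000, "batch_size": 4, "seed": 0}
    args = {"iters": 500, "batch_size": None, "seed": None}  # as if only --iters 500 were given
    config = {"batch_size": 8}  # as if read from the YAML file

    # Prefer parameters from command-line arguments
    for k, v in config.items():
        if not args.get(k, None):
            args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if not args.get(k, None):
            args[k] = v

    print(types.SimpleNamespace(**args))  # namespace(iters=500, batch_size=8, seed=0)

One caveat of the truthiness test: an explicitly falsy command-line value (say, --seed 0) is indistinguishable here from an unspecified one, so it too would be overridden by the file or the defaults.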