YAML configuration for mlx_lm.lora (#503)

* Convert mlx_lm.lora to use YAML configuration

* pre-commit run fixes

* Fix loading of config file

* Remove invalid YAML from doc

* Update command-line options and YAML parameter overriding, per feedback in #503

* Minor wording change

* Positional argument

* Moved config to a (-c/--config) flag

* Removed CLI option defaults (since CLI options take precedence and their defaults are in CONFIG_DEFAULTS)

* pre-commit format updates

* Fix handling of CLI option defaults

* Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2024-03-08 10:57:52 -05:00 committed by GitHub
parent 8b05bb6d18
commit 8c2cf665ed
3 changed files with 129 additions and 28 deletions


@@ -23,7 +23,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
## Run
The main command is `mlx_lm.lora`. To see a full list of options run:
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
```shell
python -m mlx_lm.lora --help
@@ -32,6 +32,16 @@ python -m mlx_lm.lora --help
Note: in the following, the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
You can also specify a YAML config with `-c`/`--config`. For more on the format see the
[example YAML](examples/lora_config.yaml). For example:
```shell
python -m mlx_lm.lora --config /path/to/config.yaml
```
If command-line flags are also used, they will override the corresponding
values in the config.
### Fine-tune
To fine-tune a model use:
@@ -74,7 +84,7 @@ python -m mlx_lm.lora \
### Generate
For generation use mlx_lm.generate:
For generation use `mlx_lm.generate`:
```shell
python -m mlx_lm.generate \


@@ -0,0 +1,50 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
# The PRNG seed
seed: 0
# Number of layers to fine-tune
lora_layers: 16
# Minibatch size.
batch_size: 4
# Iterations to train for.
iters: 100
# Number of validation batches, -1 uses the entire validation set.
val_batches: 25
# Adam learning rate.
learning_rate: 1e-5
# Number of training steps between loss reporting.
steps_per_report: 10
# Number of training steps between validations.
steps_per_eval: 200
# Load path to resume training with the given adapter weights.
resume_adapter_file: null
# Save/load path for the trained adapter weights.
adapter_file: "adapters.npz"
# Save the model every N iterations.
save_every: 100
# Evaluate on the test set after training
test: false
# Number of test set batches, -1 uses the entire test set.
test_batches: 500
# Maximum sequence length.
max_seq_length: 2048


@@ -1,22 +1,61 @@
import argparse
import json
import math
import re
import types
from pathlib import Path
import mlx.optimizers as optim
import numpy as np
import yaml
from mlx.utils import tree_flatten
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import linear_to_lora_layers
from .utils import load
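# PyYAML's SafeLoader follows YAML 1.1 and only resolves floats that contain a
# decimal point, so a value such as `learning_rate: 1e-5` would load as a string.
# Register a broader implicit resolver so scientific notation parses as a float.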
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
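# Fallback values, applied in __main__ below to any option that is set neither
# on the command line nor in the YAML config file.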
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"data": "data/",
"seed": 0,
"lora_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
"learning_rate": 1e-5,
"steps_per_report": 10,
"steps_per_eval": 200,
"resume_adapter_file": None,
"adapter_file": "adapters.npz",
"save_every": 100,
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
}
def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
# Generation args
@@ -24,18 +63,14 @@ def build_parser():
"--max-tokens",
"-m",
type=int,
default=100,
help="The maximum number of tokens to generate",
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
)
parser.add_argument("--temp", type=float, help="The sampling temperature")
parser.add_argument(
"--prompt",
"-p",
type=str,
help="The prompt for generation",
default=None,
)
# Training args
@@ -47,56 +82,44 @@ def build_parser():
parser.add_argument(
"--data",
type=str,
default="data/",
help="Directory with {train, valid, test}.jsonl files",
)
parser.add_argument(
"--lora-layers",
type=int,
default=16,
help="Number of layers to fine-tune",
)
parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
parser.add_argument(
"--iters", type=int, default=1000, help="Iterations to train for."
)
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
parser.add_argument("--iters", type=int, help="Iterations to train for.")
parser.add_argument(
"--val-batches",
type=int,
default=25,
help="Number of validation batches, -1 uses the entire validation set.",
)
parser.add_argument(
"--learning-rate", type=float, default=1e-5, help="Adam learning rate."
)
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
parser.add_argument(
"--steps-per-report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps-per-eval",
type=int,
default=200,
help="Number of training steps between validations.",
)
parser.add_argument(
"--resume-adapter-file",
type=str,
default=None,
help="Load path to resume training with the given adapter weights.",
)
parser.add_argument(
"--adapter-file",
type=str,
default="adapters.npz",
help="Save/load path for the trained adapter weights.",
)
parser.add_argument(
"--save-every",
type=int,
default=100,
help="Save the model every N iterations.",
)
parser.add_argument(
@@ -107,16 +130,20 @@ def build_parser():
parser.add_argument(
"--test-batches",
type=int,
default=500,
help="Number of test set batches, -1 uses the entire test set.",
)
parser.add_argument(
"--max-seq-length",
type=int,
default=2048,
help="Maximum sequence length.",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
parser.add_argument(
"-c",
"--config",
default=None,
help="A YAML configuration file with the training options",
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser
@@ -242,5 +269,19 @@ def run(args, training_callback: TrainingCallback = None):
if __name__ == "__main__":
parser = build_parser()
args = parser.parse_args()
config = args.config
args = vars(args)
if config:
print("Loading configuration file", config)
with open(config, "r") as file:
config = yaml.load(file, yaml_loader)
# Prefer parameters from command-line arguments
for k, v in config.items():
if not args.get(k, None):
args[k] = v
run(args)
# Update defaults for unspecified parameters
for k, v in CONFIG_DEFAULTS.items():
if not args.get(k, None):
args[k] = v
run(types.SimpleNamespace(**args))
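
Taken together, the two loops above give every option a three-level precedence: an explicit command-line flag wins, then the value from the YAML config, then the entry in `CONFIG_DEFAULTS`. A condensed, self-contained sketch of that merge with made-up values (illustrative only, not part of `lora.py`):

```python
# Mimics the precedence implemented above: CLI flag > YAML config > CONFIG_DEFAULTS.
cli_args = {"batch_size": 8, "iters": None, "learning_rate": None}  # only --batch-size was passed
config = {"iters": 100, "learning_rate": 1e-5}                      # read from the YAML file
defaults = {"batch_size": 4, "iters": 1000, "learning_rate": 1e-5}  # subset of CONFIG_DEFAULTS

merged = dict(cli_args)
for source in (config, defaults):
    for key, value in source.items():
        # Unspecified CLI options arrive as None and fall through, as in the loops above.
        if not merged.get(key, None):
            merged[key] = value

print(merged)  # {'batch_size': 8, 'iters': 100, 'learning_rate': 1e-05}
```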