YAML configuration for mlx_lm.lora (#503)

* Convert mlx_lm.lora to use YAML configuration

* pre-commit run fixes

* Fix loading of config file

* Remove invalid YAML from doc

* Update command-line options and YAML parameter overriding, per feedback in #503

* Minor wording change

* Positional argument

* Moved config to a (-c/--config) flag

* Removed CLI option defaults (since CLI options take precedence and their defaults are in CONFIG_DEFAULTS)

* pre-commit format updates

* Fix handling of CLI option defaults

* Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2024-03-08 10:57:52 -05:00 committed by GitHub
parent 8b05bb6d18
commit 8c2cf665ed
3 changed files with 129 additions and 28 deletions


@@ -23,14 +23,24 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 ## Run
-The main command is `mlx_lm.lora`. To see a full list of options run:
+The main command is `mlx_lm.lora`. To see a full list of command-line options run:
 ```shell
 python -m mlx_lm.lora --help
 ```
 Note, in the following the `--model` argument can be any compatible Hugging
 Face repo or a local path to a converted model.
+You can also specify a YAML config with `-c`/`--config`. For more on the format see the
+[example YAML](examples/lora_config.yaml). For example:
+```shell
+python -m mlx_lm.lora --config /path/to/config.yaml
+```
+If command-line flags are also used, they will override the corresponding
+values in the config.
 ### Fine-tune

@@ -74,7 +84,7 @@ python -m mlx_lm.lora \
 ### Generate
-For generation use mlx_lm.generate:
+For generation use `mlx_lm.generate`:
 ```shell
 python -m mlx_lm.generate \
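The README additions above describe the intended workflow: put the training options in a YAML file, pass it with `-c`/`--config`, and use command-line flags only for values you want to override. Below is a minimal sketch of that workflow as a driver script. The config values, the file name `my_lora_config.yaml`, and the assumption that `mlx_lm` is installed with a model and data directory in place are illustrative, not taken from the commit.

```python
# Hypothetical driver: write a small config, then launch mlx_lm.lora with it,
# overriding one value (--iters) on the command line.
import subprocess
import textwrap

config_text = textwrap.dedent(
    """\
    model: "mlx_model"      # placeholder path, not from the commit
    train: true
    data: "data/"           # directory with {train, valid, test}.jsonl
    iters: 200
    """
)

with open("my_lora_config.yaml", "w") as f:
    f.write(config_text)

# Per the README text above, --iters on the command line takes precedence
# over the iters value in the YAML file.
subprocess.run(
    ["python", "-m", "mlx_lm.lora",
     "--config", "my_lora_config.yaml",
     "--iters", "50"],
    check=True,
)
```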


@@ -0,0 +1,50 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
# The PRNG seed
seed: 0
# Number of layers to fine-tune
lora_layers: 16
# Minibatch size.
batch_size: 4
# Iterations to train for.
iters: 100
# Number of validation batches, -1 uses the entire validation set.
val_batches: 25
# Adam learning rate.
learning_rate: 1e-5
# Number of training steps between loss reporting.
steps_per_report: 10
# Number of training steps between validations.
steps_per_eval: 200
# Load path to resume training with the given adapter weights.
resume_adapter_file: null
# Save/load path for the trained adapter weights.
adapter_file: "adapters.npz"
# Save the model every N iterations.
save_every: 100
# Evaluate on the test set after training
test: false
# Number of test set batches, -1 uses the entire test set.
test_batches: 500
# Maximum sequence length.
max_seq_length: 2048
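Every key in the example configuration above corresponds to one entry in the `CONFIG_DEFAULTS` dictionary added to `lora.py` further down. A quick way to catch typo'd keys before starting a run is to load the file and compare it against that key set; the script below is a hypothetical sketch along those lines, not part of the commit.

```python
# Sanity-check a LoRA config file against the known option names
# (mirrors the keys of CONFIG_DEFAULTS in lora.py below).
import sys
import yaml

KNOWN_KEYS = {
    "model", "train", "data", "seed", "lora_layers", "batch_size",
    "iters", "val_batches", "learning_rate", "steps_per_report",
    "steps_per_eval", "resume_adapter_file", "adapter_file",
    "save_every", "test", "test_batches", "max_seq_length",
}

path = sys.argv[1] if len(sys.argv) > 1 else "lora_config.yaml"
with open(path) as f:
    cfg = yaml.safe_load(f) or {}

unknown = sorted(set(cfg) - KNOWN_KEYS)
if unknown:
    raise SystemExit(f"Unknown config keys: {unknown}")
print(f"{path} looks OK: {len(cfg)} options set")
```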


@@ -1,22 +1,61 @@
 import argparse
 import json
 import math
+import re
+import types
 from pathlib import Path
 import mlx.optimizers as optim
 import numpy as np
+import yaml
 from mlx.utils import tree_flatten
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import load
+yaml_loader = yaml.SafeLoader
+yaml_loader.add_implicit_resolver(
+    "tag:yaml.org,2002:float",
+    re.compile(
+        """^(?:
+     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+    |[-+]?\\.(?:inf|Inf|INF)
+    |\\.(?:nan|NaN|NAN))$""",
+        re.X,
+    ),
+    list("-+0123456789."),
+)
+CONFIG_DEFAULTS = {
+    "model": "mlx_model",
+    "train": False,
+    "data": "data/",
+    "seed": 0,
+    "lora_layers": 16,
+    "batch_size": 4,
+    "iters": 1000,
+    "val_batches": 25,
+    "learning_rate": 1e-5,
+    "steps_per_report": 10,
+    "steps_per_eval": 200,
+    "resume_adapter_file": None,
+    "adapter_file": "adapters.npz",
+    "save_every": 100,
+    "test": False,
+    "test_batches": 500,
+    "max_seq_length": 2048,
+}
 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
-        default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
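The `add_implicit_resolver` call in the hunk above exists because PyYAML's `SafeLoader` follows YAML 1.1, which only recognizes floats whose mantissa contains a decimal point, so a value like `1e-5` would otherwise load as a string. A small standalone illustration of the effect (not part of the commit):

```python
import re
import yaml

doc = "learning_rate: 1e-5"

# Stock SafeLoader: "1e-5" has no dot in the mantissa, so under YAML 1.1
# rules it is left as a string.
print(type(yaml.load(doc, yaml.SafeLoader)["learning_rate"]))   # <class 'str'>

# Register the wider float pattern, as the diff above does.
yaml.SafeLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        r"""^(?:
         [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
        |[-+]?\.(?:inf|Inf|INF)
        |\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

# The same document now parses the learning rate as a float.
print(type(yaml.load(doc, yaml.SafeLoader)["learning_rate"]))   # <class 'float'>
```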
@@ -24,18 +63,14 @@ def build_parser():
         "--max-tokens",
         "-m",
         type=int,
-        default=100,
         help="The maximum number of tokens to generate",
     )
-    parser.add_argument(
-        "--temp", type=float, default=0.8, help="The sampling temperature"
-    )
+    parser.add_argument("--temp", type=float, help="The sampling temperature")
     parser.add_argument(
         "--prompt",
         "-p",
         type=str,
         help="The prompt for generation",
-        default=None,
     )
     # Training args
@@ -47,56 +82,44 @@ def build_parser():
     parser.add_argument(
         "--data",
         type=str,
-        default="data/",
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
         "--lora-layers",
         type=int,
-        default=16,
         help="Number of layers to fine-tune",
     )
-    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
+    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
-    parser.add_argument(
-        "--iters", type=int, default=1000, help="Iterations to train for."
-    )
+    parser.add_argument("--iters", type=int, help="Iterations to train for.")
     parser.add_argument(
         "--val-batches",
         type=int,
-        default=25,
         help="Number of validation batches, -1 uses the entire validation set.",
     )
-    parser.add_argument(
-        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
-    )
+    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
     parser.add_argument(
         "--steps-per-report",
         type=int,
-        default=10,
         help="Number of training steps between loss reporting.",
     )
     parser.add_argument(
         "--steps-per-eval",
         type=int,
-        default=200,
         help="Number of training steps between validations.",
     )
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        default=None,
         help="Load path to resume training with the given adapter weights.",
     )
     parser.add_argument(
         "--adapter-file",
         type=str,
-        default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
     parser.add_argument(
         "--save-every",
         type=int,
-        default=100,
         help="Save the model every N iterations.",
     )
     parser.add_argument(
@@ -107,16 +130,20 @@ def build_parser():
     parser.add_argument(
         "--test-batches",
         type=int,
-        default=500,
         help="Number of test set batches, -1 uses the entire test set.",
     )
     parser.add_argument(
         "--max-seq-length",
         type=int,
-        default=2048,
         help="Maximum sequence length.",
     )
-    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "-c",
+        "--config",
+        default=None,
+        help="A YAML configuration file with the training options",
+    )
+    parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
@@ -242,5 +269,19 @@ def run(args, training_callback: TrainingCallback = None):
 if __name__ == "__main__":
     parser = build_parser()
     args = parser.parse_args()
-    run(args)
+    config = args.config
+    args = vars(args)
+    if config:
+        print("Loading configuration file", config)
+        with open(config, "r") as file:
+            config = yaml.load(file, yaml_loader)
+        # Prefer parameters from command-line arguments
+        for k, v in config.items():
+            if not args.get(k, None):
+                args[k] = v
+    # Update defaults for unspecified parameters
+    for k, v in CONFIG_DEFAULTS.items():
+        if not args.get(k, None):
+            args[k] = v
+    run(types.SimpleNamespace(**args))
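The `__main__` block above establishes the precedence chain the commit bullets describe: values given explicitly on the command line win, then values from the YAML config, then `CONFIG_DEFAULTS`. Because the checks use `not args.get(k, None)`, any falsy parsed value (`None`, `0`, `False`) counts as unspecified and is filled in from the next layer. A small self-contained sketch of that merge, with made-up values rather than anything from the commit:

```python
import types

# Mirrors a subset of CONFIG_DEFAULTS from lora.py.
CONFIG_DEFAULTS = {"model": "mlx_model", "iters": 1000, "learning_rate": 1e-5}

# Pretend argparse output: only --iters was passed on the command line.
cli_args = {"model": None, "iters": 2000, "learning_rate": None}

# Pretend YAML config contents (hypothetical model repo and learning rate).
config = {"model": "some-org/some-model", "learning_rate": 2e-5}

args = dict(cli_args)
for k, v in config.items():           # config fills anything the CLI left unset
    if not args.get(k, None):
        args[k] = v
for k, v in CONFIG_DEFAULTS.items():  # defaults fill whatever is still unset
    if not args.get(k, None):
        args[k] = v

print(types.SimpleNamespace(**args))
# namespace(model='some-org/some-model', iters=2000, learning_rate=2e-05)
```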