Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-09 10:26:38 +08:00)
YAML configuration for mlx_lm.lora (#503)
* Convert mlx_lm.lora to use YAML configuration
* pre-commit run fixes
* Fix loading of config file
* Remove invalid YAML from doc
* Update command-line options and YAML parameter overriding, per feedback in #503
* Minor wording change
* Positional argument
* Moved config to a (-c/--config) flag
* Removed CLI option defaults (since CLI options take precedence and their defaults are in CONFIG_DEFAULTS)
* pre-commit format updates
* Fix handling of CLI option defaults
* Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS
* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 8b05bb6d18
commit 8c2cf665ed
@@ -23,7 +23,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

## Run

The main command is `mlx_lm.lora`. To see a full list of options run:
The main command is `mlx_lm.lora`. To see a full list of command-line options run:

```shell
python -m mlx_lm.lora --help

@@ -32,6 +32,16 @@ python -m mlx_lm.lora --help

Note, in the following the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.

You can also specify a YAML config with `-c`/`--config`. For more on the format see the
[example YAML](examples/lora_config.yaml). For example:

```shell
python -m mlx_lm.lora --config /path/to/config.yaml
```

If command-line flags are also used, they will override the corresponding
values in the config.
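For instance (a sketch reusing the placeholder path above; `--iters` is one of the training flags defined further down in `lora.py`), the flag wins over whatever `iters` value the YAML file sets:

```shell
# --iters on the command line overrides iters from the config file
python -m mlx_lm.lora --config /path/to/config.yaml --iters 200
```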

### Fine-tune

To fine-tune a model use:

@@ -74,7 +84,7 @@ python -m mlx_lm.lora \

### Generate

For generation use mlx_lm.generate:
For generation use `mlx_lm.generate`:

```shell
python -m mlx_lm.generate \

llms/mlx_lm/examples/lora_config.yaml (new file, 50 lines)

@@ -0,0 +1,50 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"

# Whether or not to train (boolean)
train: true

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

# The PRNG seed
seed: 0

# Number of layers to fine-tune
lora_layers: 16

# Minibatch size.
batch_size: 4

# Iterations to train for.
iters: 100

# Number of validation batches, -1 uses the entire validation set.
val_batches: 25

# Adam learning rate.
learning_rate: 1e-5

# Number of training steps between loss reporting.
steps_per_report: 10

# Number of training steps between validations.
steps_per_eval: 200

# Load path to resume training with the given adapter weights.
resume_adapter_file: null

# Save/load path for the trained adapter weights.
adapter_file: "adapters.npz"

# Save the model every N iterations.
save_every: 100

# Evaluate on the test set after training
test: false

# Number of test set batches, -1 uses the entire test set.
test_batches: 500

# Maximum sequence length.
max_seq_length: 2048
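One way to try the new flag with this example file (a sketch assuming the repository's `llms/` directory is the working directory, so `mlx_lm` is importable) is:

```shell
# -c is the short form of --config added in this commit
python -m mlx_lm.lora -c mlx_lm/examples/lora_config.yaml
```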

@@ -1,22 +1,61 @@
import argparse
import json
import math
import re
import types
from pathlib import Path

import mlx.optimizers as optim
import numpy as np
import yaml
from mlx.utils import tree_flatten

from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import linear_to_lora_layers
from .utils import load

yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)
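The extra resolver matters because stock PyYAML follows YAML 1.1 and only resolves floats that contain a decimal point, so a value such as `learning_rate: 1e-5` would otherwise load as a string. A minimal standalone sketch of the behavior (a trimmed version of the pattern above, not the module itself):

```python
import re

import yaml

# Unpatched SafeLoader: "1e-5" has no '.', so it fails the YAML 1.1 float
# pattern and loads as the string "1e-5".
print(yaml.safe_load("learning_rate: 1e-5"))  # {'learning_rate': '1e-5'}

# Register a float resolver that also accepts exponent-only notation.
yaml.SafeLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        r"""^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
             |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
             |[-+]?\.(?:inf|Inf|INF)
             |\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

# safe_load uses SafeLoader, so the patched resolver now applies.
print(yaml.safe_load("learning_rate: 1e-5"))  # {'learning_rate': 1e-05}
```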

CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "data": "data/",
    "seed": 0,
    "lora_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_file": "adapters.npz",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
}

def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    # Generation args

@@ -24,18 +63,14 @@ def build_parser():
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="The maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument("--temp", type=float, help="The sampling temperature")
    parser.add_argument(
        "--prompt",
        "-p",
        type=str,
        help="The prompt for generation",
        default=None,
    )

    # Training args

@@ -47,56 +82,44 @@ def build_parser():
    parser.add_argument(
        "--data",
        type=str,
        default="data/",
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        default=16,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
    parser.add_argument(
        "--iters", type=int, default=1000, help="Iterations to train for."
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        default=25,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        default=200,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        default=None,
        help="Load path to resume training with the given adapter weights.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Save/load path for the trained adapter weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=100,
        help="Save the model every N iterations.",
    )
    parser.add_argument(

@@ -107,16 +130,20 @@ def build_parser():
    parser.add_argument(
        "--test-batches",
        type=int,
        default=500,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    parser.add_argument(
        "-c",
        "--config",
        default=None,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument("--seed", type=int, help="The PRNG seed")
    return parser
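Dropping the argparse defaults is what makes the later merging step possible: an option the user did not pass parses to `None`, which is distinguishable from a real value. A small self-contained sketch of that pattern (not the module above):

```python
import argparse

# Without explicit defaults, argparse stores None for options the user did
# not pass, so "not given" can be told apart from a real value later.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")

print(vars(parser.parse_args([])))
# {'batch_size': None, 'learning_rate': None}
print(vars(parser.parse_args(["--batch-size", "8"])))
# {'batch_size': 8, 'learning_rate': None}
```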

@@ -242,5 +269,19 @@ def run(args, training_callback: TrainingCallback = None):
if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    config = args.config
    args = vars(args)
    if config:
        print("Loading configuration file", config)
        with open(config, "r") as file:
            config = yaml.load(file, yaml_loader)
        # Prefer parameters from command-line arguments
        for k, v in config.items():
            if not args.get(k, None):
                args[k] = v

    run(args)
    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if not args.get(k, None):
            args[k] = v
    run(types.SimpleNamespace(**args))
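The resulting precedence is command-line flag, then YAML config, then `CONFIG_DEFAULTS`. A standalone sketch of the same merge rule with hypothetical values (it reuses the `if not args.get(k, None)` test from above):

```python
import types

CONFIG_DEFAULTS = {"batch_size": 4, "iters": 1000, "learning_rate": 1e-5}

# Pretend parse result: only --batch-size was given on the command line.
args = {"batch_size": 8, "iters": None, "learning_rate": None}

# Values read from a YAML config file fill in only what the CLI left unset.
config = {"iters": 200}
for k, v in config.items():
    if not args.get(k, None):
        args[k] = v

# CONFIG_DEFAULTS fills anything that is still unset.
for k, v in CONFIG_DEFAULTS.items():
    if not args.get(k, None):
        args[k] = v

print(types.SimpleNamespace(**args))
# namespace(batch_size=8, iters=200, learning_rate=1e-05)
```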