YAML configuration for mlx_lm.lora (#503)
* Convert mlx_lm.lora to use YAML configuration
* pre-commit run fixes
* Fix loading of config file
* Remove invalid YAML from doc
* Update command-line options and YAML parameter overriding, per feedback in #503
* Minor wording change
* Positional argument
* Moved config to a (-c/--config) flag
* Removed CLI option defaults (since CLI options take precedence and their defaults are in CONFIG_DEFAULTS)
* pre-commit format updates
* Fix handling of CLI option defaults
* Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS
* nits

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 8b05bb6d18
commit 8c2cf665ed
````diff
@@ -23,14 +23,24 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 ## Run
 
-The main command is `mlx_lm.lora`. To see a full list of options run:
+The main command is `mlx_lm.lora`. To see a full list of command-line options run:
 
 ```shell
 python -m mlx_lm.lora --help
 ```
 
 Note, in the following the `--model` argument can be any compatible Hugging
 Face repo or a local path to a converted model.
 
+You can also specify a YAML config with `-c`/`--config`. For more on the format see the
+[example YAML](examples/lora_config.yaml). For example:
+
+```shell
+python -m mlx_lm.lora --config /path/to/config.yaml
+```
+
+If command-line flags are also used, they will override the corresponding
+values in the config.
+
 ### Fine-tune
 
@@ -74,7 +84,7 @@ python -m mlx_lm.lora \
 ### Generate
 
-For generation use mlx_lm.generate:
+For generation use `mlx_lm.generate`:
 
 ```shell
 python -m mlx_lm.generate \
````
llms/mlx_lm/examples/lora_config.yaml (new file, 50 lines):

```yaml
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"

# Whether or not to train (boolean)
train: true

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

# The PRNG seed
seed: 0

# Number of layers to fine-tune
lora_layers: 16

# Minibatch size.
batch_size: 4

# Iterations to train for.
iters: 100

# Number of validation batches, -1 uses the entire validation set.
val_batches: 25

# Adam learning rate.
learning_rate: 1e-5

# Number of training steps between loss reporting.
steps_per_report: 10

# Number of training steps between validations.
steps_per_eval: 200

# Load path to resume training with the given adapter weights.
resume_adapter_file: null

# Save/load path for the trained adapter weights.
adapter_file: "adapters.npz"

# Save the model every N iterations.
save_every: 100

# Evaluate on the test set after training
test: false

# Number of test set batches, -1 uses the entire test set.
test_batches: 500

# Maximum sequence length.
max_seq_length: 2048
```
```diff
@@ -1,22 +1,61 @@
 import argparse
 import json
 import math
+import re
+import types
 from pathlib import Path
 
 import mlx.optimizers as optim
 import numpy as np
+import yaml
 from mlx.utils import tree_flatten
 
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import load
 
+yaml_loader = yaml.SafeLoader
+yaml_loader.add_implicit_resolver(
+    "tag:yaml.org,2002:float",
+    re.compile(
+        """^(?:
+     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+    |[-+]?\\.(?:inf|Inf|INF)
+    |\\.(?:nan|NaN|NAN))$""",
+        re.X,
+    ),
+    list("-+0123456789."),
+)
+
+
+CONFIG_DEFAULTS = {
+    "model": "mlx_model",
+    "train": False,
+    "data": "data/",
+    "seed": 0,
+    "lora_layers": 16,
+    "batch_size": 4,
+    "iters": 1000,
+    "val_batches": 25,
+    "learning_rate": 1e-5,
+    "steps_per_report": 10,
+    "steps_per_eval": 200,
+    "resume_adapter_file": None,
+    "adapter_file": "adapters.npz",
+    "save_every": 100,
+    "test": False,
+    "test_batches": 500,
+    "max_seq_length": 2048,
+}
+
+
 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
-        default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
```
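A note on the `yaml_loader` patch in the hunk above: PyYAML's stock `SafeLoader` follows YAML 1.1 resolution rules, under which a number written with an exponent but no decimal point (such as `learning_rate: 1e-5` in the example config) loads as a string rather than a float. The extra implicit resolver widens the float regex so that form resolves to a float. Below is a minimal, self-contained sketch of the same idea; the `FloatFriendlyLoader` subclass is hypothetical and is used only so the demo does not mutate the global `SafeLoader` the way the in-tree patch does.

```python
import re
import yaml

# Stock SafeLoader (YAML 1.1 rules): "1e-5" is not matched by the float
# resolver because its regex requires a decimal point, so it loads as a string.
print(yaml.safe_load("learning_rate: 1e-5"))    # {'learning_rate': '1e-5'}
print(yaml.safe_load("learning_rate: 1.0e-5"))  # {'learning_rate': 1e-05}

# Registering a broader float regex (a simplified version of the one in the
# diff above) on a subclass makes the dotless exponent form parse as a float.
class FloatFriendlyLoader(yaml.SafeLoader):
    pass

FloatFriendlyLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        r"""^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?\.(?:inf|Inf|INF)
            |\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

print(yaml.load("learning_rate: 1e-5", FloatFriendlyLoader))  # {'learning_rate': 1e-05}
```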
```diff
@@ -24,18 +63,14 @@ def build_parser():
         "--max-tokens",
         "-m",
         type=int,
-        default=100,
         help="The maximum number of tokens to generate",
     )
-    parser.add_argument(
-        "--temp", type=float, default=0.8, help="The sampling temperature"
-    )
+    parser.add_argument("--temp", type=float, help="The sampling temperature")
     parser.add_argument(
         "--prompt",
         "-p",
         type=str,
         help="The prompt for generation",
-        default=None,
     )
 
     # Training args
@@ -47,56 +82,44 @@ def build_parser():
     parser.add_argument(
         "--data",
         type=str,
-        default="data/",
         help="Directory with {train, valid, test}.jsonl files",
     )
     parser.add_argument(
         "--lora-layers",
         type=int,
-        default=16,
         help="Number of layers to fine-tune",
     )
-    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
-    parser.add_argument(
-        "--iters", type=int, default=1000, help="Iterations to train for."
-    )
+    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
+    parser.add_argument("--iters", type=int, help="Iterations to train for.")
     parser.add_argument(
         "--val-batches",
         type=int,
-        default=25,
         help="Number of validation batches, -1 uses the entire validation set.",
     )
-    parser.add_argument(
-        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
-    )
+    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
     parser.add_argument(
         "--steps-per-report",
         type=int,
-        default=10,
         help="Number of training steps between loss reporting.",
     )
     parser.add_argument(
         "--steps-per-eval",
         type=int,
-        default=200,
         help="Number of training steps between validations.",
     )
     parser.add_argument(
         "--resume-adapter-file",
         type=str,
-        default=None,
         help="Load path to resume training with the given adapter weights.",
     )
     parser.add_argument(
         "--adapter-file",
         type=str,
-        default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
     parser.add_argument(
         "--save-every",
         type=int,
-        default=100,
         help="Save the model every N iterations.",
     )
     parser.add_argument(
@@ -107,16 +130,20 @@ def build_parser():
     parser.add_argument(
         "--test-batches",
         type=int,
-        default=500,
         help="Number of test set batches, -1 uses the entire test set.",
     )
     parser.add_argument(
         "--max-seq-length",
         type=int,
-        default=2048,
         help="Maximum sequence length.",
     )
-    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "-c",
+        "--config",
+        default=None,
+        help="A YAML configuration file with the training options",
+    )
+    parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
 
 
@@ -242,5 +269,19 @@ def run(args, training_callback: TrainingCallback = None):
 if __name__ == "__main__":
     parser = build_parser()
     args = parser.parse_args()
+    config = args.config
+    args = vars(args)
+    if config:
+        print("Loading configuration file", config)
+        with open(config, "r") as file:
+            config = yaml.load(file, yaml_loader)
+        # Prefer parameters from command-line arguments
+        for k, v in config.items():
+            if not args.get(k, None):
+                args[k] = v
 
-    run(args)
+    # Update defaults for unspecified parameters
+    for k, v in CONFIG_DEFAULTS.items():
+        if not args.get(k, None):
+            args[k] = v
+    run(types.SimpleNamespace(**args))
```
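Taken together, the `__main__` changes above give each option a three-level precedence: an explicitly passed command-line flag wins, otherwise the value from the YAML config is used, otherwise the entry from `CONFIG_DEFAULTS`. A small standalone sketch of that merge, with hypothetical values and the same `not args.get(k, None)` test as the code above:

```python
import types

CONFIG_DEFAULTS = {"batch_size": 4, "iters": 1000, "learning_rate": 1e-5}

# Values as argparse would leave them: options the user did not pass are None.
cli_args = {"batch_size": None, "iters": 200, "learning_rate": None}
yaml_config = {"batch_size": 8}  # hypothetical config file contents

args = dict(cli_args)

# YAML values fill in options the user did not pass on the command line.
for k, v in yaml_config.items():
    if not args.get(k, None):
        args[k] = v

# CONFIG_DEFAULTS fill in anything still unset.
for k, v in CONFIG_DEFAULTS.items():
    if not args.get(k, None):
        args[k] = v

print(types.SimpleNamespace(**args))
# namespace(batch_size=8, iters=200, learning_rate=1e-05)
```

Because unspecified argparse options come back as `None` (their real defaults now live in `CONFIG_DEFAULTS`), the falsy check lets config and default values fill them in without clobbering flags the user actually passed.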
Loading…
Reference in New Issue
Block a user