# Source: mlx-examples/llms/mlx_lm/lora.py (Python, 279 lines, 7.6 KiB)

# Copyright © 2024 Apple Inc.
import argparse
import math
import re
import types
from pathlib import Path
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
apply_lora_layers,
build_schedule,
linear_to_lora_layers,
print_trainable_parameters,
)
from .utils import load, save_config
# Loader class handed to `yaml.load` when reading a `--config` file.
# NOTE: this binds the name to `yaml.SafeLoader` itself (no subclass is
# created), so the resolver registered below is added to that class rather
# than to a private copy.
yaml_loader = yaml.SafeLoader
# Register an implicit resolver that tags plain scalars matching the pattern
# below as YAML floats. The alternatives of the pattern are, in order:
#   1. digits, a dot, optional digits, optional exponent
#   2. digits followed directly by an exponent, with no dot (e.g. `1e-5`)
#   3. a leading dot, digits, optional signed exponent
#   4. colon-separated groups ending in a fractional part (e.g. `1:30.5`)
#   5. optionally signed `.inf` / `.Inf` / `.INF`
#   6. `.nan` / `.NaN` / `.NAN`
# Presumably the point is that values such as `learning_rate: 1e-5` load as
# floats rather than strings -- confirm against PyYAML's default resolver.
yaml_loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
# Third argument: the characters a scalar can begin with for this resolver
# to apply (see PyYAML's `add_implicit_resolver` `first` parameter).
list("-+0123456789."),
)
# Fallback value for each option that is still None after the command line
# and the optional YAML config file have been consulted; `main` copies these
# in key by key.
CONFIG_DEFAULTS = dict(
    # Model and data locations
    model="mlx_model",
    train=False,
    data="data/",
    seed=0,
    # Fine-tuning hyper-parameters
    lora_layers=16,
    batch_size=4,
    iters=1000,
    val_batches=25,
    learning_rate=1e-5,
    # Reporting / evaluation cadence (in training steps)
    steps_per_report=10,
    steps_per_eval=200,
    # Adapter checkpoints
    resume_adapter_file=None,
    adapter_path="adapters",
    save_every=100,
    # Test-set evaluation
    test=False,
    test_batches=500,
    max_seq_length=2048,
    lr_schedule=None,
    lora_parameters=dict(rank=8, alpha=16, dropout=0.0, scale=10.0),
    use_dora=False,
)
def build_parser():
    """Build the command-line parser for LoRA / QLoRA fine-tuning.

    Every option defaults to ``None`` (including the ``store_true`` flags,
    which would otherwise default to ``False``). ``main`` relies on this:
    a value of ``None`` means "not given on the command line", so it can be
    filled in from the YAML config file and then from ``CONFIG_DEFAULTS``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        help="The path to the local model directory or Hugging Face repo.",
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
        default=None,
    )
    parser.add_argument(
        "--data",
        type=str,
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        help="Number of layers to fine-tune. Default is 16, use -1 for all.",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training with the given adapters.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the adapters.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
        default=None,
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        default=None,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
        default=None,
    )
    parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
    parser.add_argument(
        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
    )
    return parser
def train_model(
    args,
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    train_set,
    valid_set,
    training_callback: TrainingCallback = None,
):
    """Convert ``model`` to LoRA layers and fine-tune it.

    Args:
        args: Namespace of options. Reads ``lora_layers``, ``lora_parameters``,
            ``resume_adapter_file``, ``adapter_path``, ``batch_size``,
            ``iters``, ``val_batches``, ``steps_per_report``,
            ``steps_per_eval``, ``save_every``, ``max_seq_length``,
            ``grad_checkpoint``, ``lr_schedule`` and ``learning_rate``.
        model: The model to fine-tune; modified in place.
        tokenizer: Tokenizer forwarded to ``train``.
        train_set: Training dataset forwarded to ``train``.
        valid_set: Validation dataset forwarded to ``train``.
        training_callback: Optional callback forwarded to ``train``.

    Side effects:
        Creates ``args.adapter_path`` if needed and writes the full set of
        options to ``adapter_config.json`` inside it.
    """
    # Freeze all layers
    model.freeze()

    # Convert linear layers to lora layers and unfreeze in the process
    # NOTE(review): `args.use_dora` is parsed and defaulted in this file but is
    # never read here -- confirm whether `linear_to_lora_layers` should get it.
    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        # The path lives on `args`; a bare `resume_adapter_file` name does not
        # exist in this scope and would raise NameError.
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        # strict=False: the adapter file is loaded non-strictly, i.e. without
        # requiring it to cover every parameter of the model.
        model.load_weights(args.resume_adapter_file, strict=False)

    print_trainable_parameters(model)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)
    adapter_file = adapter_path / "adapters.safetensors"
    # Persist the options next to the adapter weights.
    save_config(vars(args), adapter_path / "adapter_config.json")

    # init training args
    training_args = TrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=adapter_file,
        max_seq_length=args.max_seq_length,
        grad_checkpoint=args.grad_checkpoint,
    )

    model.train()
    # A configured schedule takes precedence over the constant learning rate.
    opt = optim.Adam(
        learning_rate=(
            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
        )
    )

    # Train model
    train(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        training_callback=training_callback,
    )
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
    """Evaluate ``model`` on ``test_set`` and print its loss and perplexity.

    Args:
        args: Namespace of options. Reads ``batch_size``, ``test_batches``
            and ``max_seq_length``.
        model: The model to evaluate; switched to eval mode first.
        tokenizer: Tokenizer forwarded to ``evaluate``.
        test_set: Dataset forwarded to ``evaluate``.
    """
    model.eval()

    eval_options = dict(
        model=model,
        dataset=test_set,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.test_batches,
        max_seq_length=args.max_seq_length,
    )
    loss = evaluate(**eval_options)

    # Perplexity is the exponential of the loss returned by `evaluate`.
    perplexity = math.exp(loss)
    print(f"Test loss {loss:.3f}, Test ppl {perplexity:.3f}.")
def run(args, training_callback: TrainingCallback = None):
    """Load the model and datasets, then train and/or test as requested.

    Args:
        args: Namespace of options. Reads ``seed``, ``model``, ``train``,
            ``test`` and ``adapter_path`` here and is passed on to the
            dataset loader, ``train_model`` and ``evaluate_model``.
        training_callback: Optional callback forwarded to ``train_model``.

    Raises:
        ValueError: If neither ``args.train`` nor ``args.test`` is set. This
            is raised only after the model and datasets have been loaded.
    """
    # Seed NumPy's global random number generator.
    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    if args.train:
        print("Training")
        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
    elif args.test:
        # Test-only run.
        # Allow testing without LoRA layers by providing empty path
        if args.adapter_path != "":
            apply_lora_layers(model, args.adapter_path)
    else:
        raise ValueError("Must provide at least one of --train or --test")

    if args.test:
        print("Testing")
        evaluate_model(args, model, tokenizer, test_set)
def main():
    """Command-line entry point: resolve options and call ``run``.

    Options are resolved with this precedence, highest first:

    1. values given on the command line,
    2. values from the YAML file passed via ``-c/--config``,
    3. ``CONFIG_DEFAULTS``.

    An option counts as "unset" while its value is ``None``, which is why the
    parser defaults every option to ``None``.
    """
    parser = build_parser()
    args = vars(parser.parse_args())

    config_path = args["config"]
    if config_path:
        print("Loading configuration file", config_path)
        with open(config_path, "r") as file:
            # `yaml_loader` is the safe loader. An empty YAML document loads
            # as None, so fall back to an empty mapping instead of crashing
            # on `.items()` below.
            config = yaml.load(file, yaml_loader) or {}
        # Prefer parameters from command-line arguments
        for k, v in config.items():
            if args.get(k) is None:
                args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if args.get(k) is None:
            args[k] = v

    run(types.SimpleNamespace(**args))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()