# Copyright © 2024 Apple Inc.

import argparse
import math
import re
import types
from pathlib import Path

import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from mlx.utils import tree_flatten

from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import build_schedule, linear_to_lora_layers
from .utils import load, save_config

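# PyYAML's SafeLoader follows YAML 1.1 and does not resolve scalars such as
# `1e-5` (scientific notation without a decimal point) as floats, so a value
# like `learning_rate: 1e-5` in a config file would otherwise load as a string.
# The resolver below widens the float tag to cover that notation plus inf/nan.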
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
        [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

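
# Fallback values for every option; applied in __main__ after the CLI and the
# optional YAML config, so precedence is: command line > config file > defaults.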
CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "data": "data/",
    "seed": 0,
    "lora_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_path": "adapters",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
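
# A minimal sketch of a YAML file that could be passed with -c/--config
# (hypothetical values; any of the keys above can appear):
#
#   model: mlx_model
#   train: true
#   data: data/
#   iters: 600
#   learning_rate: 1e-5
#   lora_parameters:
#     rank: 8
#     alpha: 16
#     dropout: 0.0
#     scale: 10.0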


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument(
        "--data",
        type=str,
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training with the given adapters.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the adapters.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        default=None,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser
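
# Example invocation (hypothetical paths; assumes this module is part of the
# `mlx_lm` package and is run as a script):
#
#   python -m mlx_lm.lora --model <path-or-hf-repo> --train --data data/ --iters 1000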


def print_trainable_parameters(model):
    def nparams(m):
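        # A QuantizedLinear keeps its weights packed into 32-bit words, so the
        # raw array size undercounts; scale by (32 // bits) to recover the
        # number of original parameters.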
        if isinstance(m, nn.QuantizedLinear):
            return m.weight.size * (32 // m.bits)
        return sum(v.size for _, v in tree_flatten(m.parameters()))

    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    print(
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )


def run(args, training_callback: TrainingCallback = None):
    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    # Freeze all layers
    model.freeze()
    # Convert linear layers to lora layers and unfreeze in the process
    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)

    print_trainable_parameters(model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)
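    # Save the fully resolved training options next to the adapters so the run
    # can be inspected or the adapters reloaded later.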
    save_config(vars(args), adapter_path / "adapter_config.json")
    adapter_file = adapter_path / "adapters.safetensors"

    if args.train:
        print("Training")
        # init training args
        training_args = TrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            grad_checkpoint=args.grad_checkpoint,
        )

        model.train()
        opt = optim.Adam(
            learning_rate=(
                build_schedule(args.lr_schedule)
                if args.lr_schedule
                else args.learning_rate
            )
        )
        # Train model
        train(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            training_callback=training_callback,
        )

    # Load the LoRA adapter weights which we assume should exist by this point
    if not adapter_file.is_file():
        raise ValueError(
            f"Adapter file {adapter_file} missing. "
            "Use --train to learn and save the adapters"
        )
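    # strict=False: the adapter file contains only the LoRA parameters, not the
    # full set of model weights.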
    model.load_weights(str(adapter_file), strict=False)

    if args.test:
        print("Testing")
        model.eval()

        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
        )

        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    config = args.config
    args = vars(args)
    if config:
        print("Loading configuration file", config)
        with open(config, "r") as file:
            config = yaml.load(file, yaml_loader)
        # Prefer parameters from command-line arguments
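        # (a CLI value counts as "unspecified" when it is falsy, e.g. None,
        # False, or 0, and is then replaced by the config value)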
        for k, v in config.items():
            if not args.get(k, None):
                args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if not args.get(k, None):
            args[k] = v
    run(types.SimpleNamespace(**args))