Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
feat: move lora into mlx-lm (#337)

* feat: Add lora and qlora training to mlx-lm

Co-authored-by: Awni Hannun <awni@apple.com>
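For reference, the Dataset class in the script below reads {train, valid, test}.jsonl from the --data directory (default data/) and takes the training text from the "text" field of each JSON line. A minimal sketch of preparing such a file (the directory name and example texts here are illustrative, not part of the commit):

import json
from pathlib import Path

# Write a tiny train.jsonl in the one-JSON-object-per-line format the script expects.
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
examples = [
    {"text": "Q: What is the capital of France?\nA: Paris."},
    {"text": "Q: What is 2 + 2?\nA: 4."},
]
with open(data_dir / "train.jsonl", "w") as fid:
    for example in examples:
        fid.write(json.dumps(example) + "\n")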
llms/mlx_lm/lora.py (new file, 252 lines)
@@ -0,0 +1,252 @@
import argparse
import json
import math
from pathlib import Path

import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten

from .models import llama, mixtral, phi2
from .tuner.lora import LoRALinear
from .tuner.trainer import TrainingArgs, evaluate, train
from .utils import generate, load

SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    # Generation args
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="The maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument(
        "--prompt",
        "-p",
        type=str,
        help="The prompt for generation",
        default=None,
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="data/",
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        default=16,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
    parser.add_argument(
        "--iters", type=int, default=1000, help="Iterations to train for."
    )
    parser.add_argument(
        "--val-batches",
        type=int,
        default=25,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
    )
    parser.add_argument(
        "--steps-per-report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        default=200,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        default=None,
        help="Load path to resume training with the given adapter weights.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Save/load path for the trained adapter weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=100,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        default=500,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)


def load_dataset(args):
    names = ("train", "valid", "test")
    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    if model.__class__ not in SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model.__class__} not supported. "
            f"Supported models: {SUPPORTED_MODELS}"
        )

    # Freeze all layers other than LoRA linears
    model.freeze()
    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
        if hasattr(l, "block_sparse_moe"):
            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    print(f"Total parameters {p:.3f}M")
    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    print(f"Trainable parameters {p:.3f}M")

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)
    # Init training args
    trainingArgs = TrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=args.adapter_file,
    )
    if args.train:
        print("Training")
        model.train()
        opt = optim.Adam(learning_rate=args.learning_rate)
        # Train model
        train(
            model=model,
            tokenizer=tokenizer,
            args=trainingArgs,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
        )

    # Load the LoRA adapter weights which we assume should exist by this point
    if not Path(args.adapter_file).is_file():
        raise ValueError(
            f"Adapter file {args.adapter_file} missing. "
            "Use --train to learn and save the adapters.npz."
        )
    model.load_weights(args.adapter_file, strict=False)

    if args.test:
        print("Testing")
        model.eval()

        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
        )

        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

    if args.prompt is not None:
        print("Generating")
        generate(
            model=model,
            tokenizer=tokenizer,
            temp=args.temp,
            max_tokens=args.max_tokens,
            prompt=args.prompt,
            verbose=True,
        )
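The script relies on LoRALinear.from_linear (from .tuner.lora) to swap the attention projections for trainable low-rank adapters while the rest of the model stays frozen. As a rough illustration of the idea only, not the implementation in the package's tuner/lora.py (which also handles details such as scaling, dtypes, and quantized layers for QLoRA), a LoRA wrapper around an MLX linear layer can be sketched as:

import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative only: base projection plus a trainable low-rank update."""

    def __init__(self, linear: nn.Linear, rank: int = 8, scale: float = 2.0):
        super().__init__()
        self.linear = linear  # the original (frozen) projection
        out_dims, in_dims = linear.weight.shape
        # Low-rank factors A and B: only these are trained.
        bound = 1.0 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(in_dims, rank))
        self.lora_b = mx.zeros((rank, out_dims))
        self.scale = scale

    def __call__(self, x):
        # y = W x + scale * (x A) B
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)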