# mlx-examples/llms/mlx_lm/lora.py
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import re
import types
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from .tuner.grpo_trainer import (
    GRPOTrainingArgs,
    evaluate_grpo,
    get_per_token_logps,  # used by compute_grpo_loss_and_grad below; assumed to be defined in the GRPO trainer module
    train_grpo,
)
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
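# Teach the YAML SafeLoader to recognize scientific-notation floats (e.g. 1e-5),
# so config values such as learning rates are parsed as floats rather than strings.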
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
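# Fallback values applied when an option is set neither on the command line nor
# in the YAML config (see main() for the precedence logic).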
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"training_mode": "normal",
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
"num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
"learning_rate": 1e-5,
"steps_per_report": 10,
"steps_per_eval": 200,
"resume_adapter_file": None,
"adapter_path": "adapters",
"save_every": 100,
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"mask_prompt": False,
# GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
"temperature": 1.0,
"reward_weights": None
}
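# Most options default to None (including the boolean flags) so that main() can
# detect which ones were set explicitly before merging in the YAML config and
# CONFIG_DEFAULTS.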
def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
)
# Training args
parser.add_argument(
"--train",
action="store_true",
help="Do training",
default=None,
)
parser.add_argument(
"--data",
type=str,
help=(
"Directory with {train, valid, test}.jsonl files or the name "
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
),
)
parser.add_argument(
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=None,
)
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "grpo"],
help="Training mode: normal or GRPO",
)
parser.add_argument(
"--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
parser.add_argument("--iters", type=int, help="Iterations to train for.")
parser.add_argument(
"--val-batches",
type=int,
help="Number of validation batches, -1 uses the entire validation set.",
)
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
parser.add_argument(
"--steps-per-report",
type=int,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps-per-eval",
type=int,
help="Number of training steps between validations.",
)
parser.add_argument(
"--resume-adapter-file",
type=str,
help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
type=int,
help="Save the model every N iterations.",
)
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
default=None,
)
parser.add_argument(
"--test-batches",
type=int,
help="Number of test set batches, -1 uses the entire test set.",
)
parser.add_argument(
"--max-seq-length",
type=int,
help="Maximum sequence length.",
)
parser.add_argument(
"-c",
"--config",
type=str,
help="A YAML configuration file with the training options",
)
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
# GRPO args
parser.add_argument(
"--group-size",
type=int,
help="Number of generations.",
default=4,
)
parser.add_argument(
"--max-completion-length",
type=int,
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
default=512,
)
parser.add_argument(
"--beta",
type=float,
help="KL penalty coefficient.",
default=0.1,
)
parser.add_argument(
"--epsilon",
type=float,
help="The Epsilon for numerical stability.",
default=1e-4,
)
parser.add_argument(
"--use-chat-template",
action="store_true",
help="If the model is a Chat model, use the Chat template.",
default=None,
)
parser.add_argument(
"--use-prompt",
action="store_true",
help="Rather to use the prompt from the R1 paper.",
default=None,
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature for sampling. The higher the temperature, the more random the completions.",
default=1.0,
)
parser.add_argument(
"--reward-weights",
type=str,
help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
default=None,
)
return parser
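# GRPO fine-tuning: builds GRPOTrainingArgs from the parsed options, loads a
# frozen reference model (the base model if no --reference-model-path is given),
# and dispatches to train_grpo.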
def train_model_grpo(
    model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback
):
training_args = GRPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path,
temperature=args.temperature,
        reward_weights=(
            [float(x) for x in args.reward_weights.strip("[]").split(",")]
            if args.reward_weights
            else None
        ),
)
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)
train_grpo(
model=model,
ref_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
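# Standard fine-tuning entry point: freezes the base model, applies LoRA/DoRA or
# full-layer unfreezing, sets up the optimizer, and runs either normal or GRPO
# training depending on --training-mode.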
def train_model(
args,
model: nn.Module,
tokenizer: TokenizerWrapper,
train_set,
valid_set,
training_callback: TrainingCallback = None,
):
model.freeze()
if args.num_layers > len(model.layers):
raise ValueError(
f"Requested to train {args.num_layers} layers "
f"but the model only has {len(model.layers)} layers."
)
if args.fine_tune_type == "full":
for l in model.layers[-max(args.num_layers, 0) :]:
l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
# Convert linear layers to lora/dora layers and unfreeze in the process
linear_to_lora_layers(
model,
args.num_layers,
args.lora_parameters,
use_dora=(args.fine_tune_type == "dora"),
)
else:
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
# Resume from weights if provided
if args.resume_adapter_file is not None:
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
if args.training_mode == "grpo":
train_model_grpo(
model,
tokenizer,
args,
opt,
train_set,
valid_set,
adapter_file,
training_callback
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
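# Test-set evaluation: uses evaluate_grpo with a frozen reference model in GRPO
# mode, otherwise the standard evaluate, and reports loss, perplexity, and (for
# GRPO) per-reward-function metrics.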
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
if args.training_mode == "grpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)
test_loss, _, test_rewards = evaluate_grpo(
model=model,
ref_model=reference_model.freeze(),
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
temperature=args.temperature,
max_tokens=args.max_seq_length
)
test_ppl = math.exp(test_loss)
rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
load_adapters(model, args.adapter_path)
elif args.train:
print("Training")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")
if args.test:
print("Testing")
evaluate_model(args, model, tokenizer, test_set)
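# CLI entry point: command-line values take precedence over the YAML config,
# which in turn takes precedence over CONFIG_DEFAULTS.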
def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser()
args = parser.parse_args()
config = args.config
args = vars(args)
if config:
print("Loading configuration file", config)
with open(config, "r") as file:
config = yaml.load(file, yaml_loader)
# Prefer parameters from command-line arguments
for k, v in config.items():
if args.get(k, None) is None:
args[k] = v
# Update defaults for unspecified parameters
for k, v in CONFIG_DEFAULTS.items():
if args.get(k, None) is None:
args[k] = v
run(types.SimpleNamespace(**args))
if __name__ == "__main__":
main()
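# Example invocations (assuming the package is importable as `mlx_lm`; the model
# and data paths are placeholders):
#
#   python -m mlx_lm.lora --model <path-or-hf-repo> --train --data ./data
#   python -m mlx_lm.lora --model <path-or-hf-repo> --train --data ./data \
#       --training-mode grpo --group-size 4 --beta 0.1 --max-completion-length 512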
def compute_grpo_loss_and_grad(
    model,
    ref_model,
    tokenizer,
    completion_tensors,
    prompt_texts,
    answer_texts,
    beta=0.1,
    epsilon=1e-4,
    reward_funcs=None,
    reward_weights=None,
):
"""
Compute GRPO loss and gradients using pre-generated completions.
Args:
model: The policy model
ref_model: The reference model
completion_tensors: List of tensors containing generated completions
prompt_texts: List of prompt texts
answer_texts: List of answer texts
beta: KL penalty coefficient
epsilon: Numerical stability constant
reward_funcs: List of reward functions
reward_weights: Optional weights for reward functions
"""
# Ensure model is in training mode for gradient computation
model.train()
# Get completion texts for reward calculation
completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]
# Prepare inputs for loss computation
max_length = max(tensor.shape[0] for tensor in completion_tensors)
padded_completions = []
attention_masks = []
for completion_tensor in completion_tensors:
padding_length = max_length - completion_tensor.shape[0]
if padding_length > 0:
padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
padded_ids = mx.concatenate([completion_tensor, padding])
mask = mx.concatenate(
[mx.ones_like(completion_tensor), mx.zeros_like(padding)]
)
else:
padded_ids = completion_tensor
mask = mx.ones_like(completion_tensor)
padded_completions.append(padded_ids)
attention_masks.append(mask)
inputs = mx.stack(padded_completions)
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
# Compute log probabilities for both models
token_log_probs = get_per_token_logps(model, inputs, lengths)
if ref_model is None:
ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]
# Pad log probabilities to same length
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
padded_ref_log_probs = []
for i in range(len(token_log_probs)):
seq_len = token_log_probs[i].shape[0]
padding = mx.zeros((max_len - seq_len,))
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
# Calculate rewards
all_func_rewards = []
for reward_func in reward_funcs:
func_rewards = mx.array(
reward_func(
prompts=prompt_texts,
completions=completion_texts,
answer=answer_texts,
)
)
all_func_rewards.append(func_rewards)
# Stack rewards and apply weights
rewards = mx.stack(all_func_rewards, axis=1)
if reward_weights is not None:
if len(reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
# Group rewards by prompt (assuming completions are grouped by prompt)
group_size = len(completion_tensors) // len(prompt_texts)
if len(completion_tensors) % len(prompt_texts) != 0:
raise ValueError("Number of completions must be divisible by number of prompts")
rewards_by_group = []
for i in range(0, len(rewards), group_size):
rewards_by_group.append(rewards[i:i+group_size])
# Calculate advantages
advantages = mx.zeros_like(rewards)
for i, group_rewards in enumerate(rewards_by_group):
if len(group_rewards) > 1: # Only normalize if we have multiple samples
mean_reward = mx.mean(group_rewards)
std_reward = mx.std(group_rewards)
for j in range(group_size):
idx = i * group_size + j
advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
else:
# If only one sample, advantage is 0
advantages[i * group_size] = 0.0
# Compute KL divergence
kl_div = (
mx.exp(ref_token_log_probs - token_log_probs)
- (ref_token_log_probs - token_log_probs)
- 1
)
# Create mask for valid tokens
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
# Compute policy ratio
policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)
# Compute per-token loss
per_token_loss = -(
(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
)
# Average over tokens
sequence_sums = per_token_loss.sum(axis=1)
sequence_lengths = length_mask.sum(axis=1)
loss = (sequence_sums / sequence_lengths).mean()
# Calculate metrics for reporting
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
metrics = {
"total_rewards_mean": mx.mean(rewards),
"total_rewards_std": mx.std(rewards),
"kl": mean_kl,
}
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
func_rewards = all_func_rewards[i]
metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
metrics[f"{func_name}_std"] = mx.std(func_rewards)
return loss, sequence_lengths.sum(), metrics
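# A minimal sketch of how the loss above could be differentiated, assuming the
# completions, prompts, answers, and reward functions have already been produced
# by a GRPO sampling step (none of these names are defined in this file):
#
#   def loss_fn(policy):
#       loss, num_tokens, metrics = compute_grpo_loss_and_grad(
#           policy, reference_model, tokenizer, completion_tensors,
#           prompt_texts, answer_texts, beta=0.1, reward_funcs=reward_funcs,
#       )
#       return loss
#
#   loss, grads = nn.value_and_grad(model, loss_fn)(model)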