# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import re
import types
from pathlib import Path

import mlx.core as mx  # used by compute_grpo_loss_and_grad below
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml

from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
    load_adapters,
    print_trainable_parameters,
)
from .utils import load, save_config

yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)
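
# Example: with the resolver above, a config value like `learning_rate: 1e-5`
# loads as the float 1e-5; PyYAML's default float resolver requires a '.' in
# scientific notation and would otherwise return the string "1e-5".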

CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "training_mode": "normal",
    "fine_tune_type": "lora",
    "data": "data/",
    "seed": 0,
    "num_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_path": "adapters",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "config": None,
    "grad_checkpoint": False,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    "mask_prompt": False,
    # GRPO args
    "reference_model_path": None,
    "group_size": 4,
    "beta": 0.1,
    "epsilon": 1e-4,
    "max_completion_length": 512,
    "use_chat_template": False,
    "use_prompt": False,
    "temperature": 1.0,
    "reward_weights": None,
}
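
# Example config (illustrative sketch): a YAML file passed via `--config`, using
# the same keys as CONFIG_DEFAULTS above. Keys omitted here fall back to the
# defaults, and command-line flags take precedence over the file. The model repo
# name is a hypothetical placeholder.
#
#   model: mlx-community/some-4bit-model
#   train: true
#   data: data/
#   fine_tune_type: lora
#   num_layers: 16
#   batch_size: 4
#   iters: 1000
#   learning_rate: 1e-5
#   lora_parameters:
#     rank: 8
#     alpha: 16
#     dropout: 0.0
#     scale: 10.0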


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
        default=None,
    )
    parser.add_argument(
        "--data",
        type=str,
        help=(
            "Directory with {train, valid, test}.jsonl files or the name "
            "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
        ),
    )
    parser.add_argument(
        "--fine-tune-type",
        type=str,
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=None,
    )
    parser.add_argument(
        "--training-mode",
        type=str,
        choices=["normal", "grpo"],
        help="Training mode: normal or GRPO",
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        help="Number of layers to fine-tune. Default is 16, use -1 for all.",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training from the given fine-tuned weights.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the fine-tuned weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
        default=None,
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
        default=None,
    )
    parser.add_argument("--seed", type=int, help="The PRNG seed")

    # GRPO args
    parser.add_argument(
        "--group-size",
        type=int,
        help="Number of completions generated per prompt.",
        default=4,
    )
    parser.add_argument(
        "--max-completion-length",
        type=int,
        help="Maximum number of tokens to generate per completion.",
        default=512,
    )
    parser.add_argument(
        "--beta",
        type=float,
        help="KL penalty coefficient.",
        default=0.1,
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        help="Epsilon for numerical stability.",
        default=1e-4,
    )
    parser.add_argument(
        "--use-chat-template",
        action="store_true",
        help="If the model is a chat model, use the chat template.",
        default=None,
    )
    parser.add_argument(
        "--use-prompt",
        action="store_true",
        help="Whether to use the prompt from the R1 paper.",
        default=None,
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for sampling. The higher the temperature, the more random the completions.",
        default=1.0,
    )
    parser.add_argument(
        "--reward-weights",
        type=str,
        help=(
            "Weights for each reward function, e.g. '[0.1, 0.2, 0.3, 0.4, 0.5]'. "
            "Must match the number of reward functions. If not given, all rewards "
            "are weighted equally with weight 1.0."
        ),
        default=None,
    )
    return parser
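
# Usage sketch (illustrative; assumes this file is the `mlx_lm.lora` entry point
# as in mlx-examples, and `<path_or_hf_repo>` is a placeholder):
#
#   LoRA fine-tuning:
#     python -m mlx_lm.lora --model <path_or_hf_repo> --train --data data/
#
#   GRPO training:
#     python -m mlx_lm.lora --model <path_or_hf_repo> --train --data data/ \
#         --training-mode grpo --group-size 4 --beta 0.1 \
#         --reward-weights "[1.0, 0.5]"
#
# Note that `reference_model_path` has no dedicated CLI flag above; it can only
# be set through a YAML config, otherwise `--model` is loaded again as the
# reference model.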


def train_model_grpo(
    model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback
):
    training_args = GRPOTrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=adapter_file,
        max_seq_length=args.max_seq_length,
        max_completion_length=args.max_completion_length,
        grad_checkpoint=args.grad_checkpoint,
        beta=args.beta,
        group_size=args.group_size,
        epsilon=args.epsilon,
        reference_model_path=args.reference_model_path,
        temperature=args.temperature,
        reward_weights=(
            [float(x) for x in args.reward_weights.strip("[]").split(",")]
            if args.reward_weights
            else None
        ),
    )

    if args.reference_model_path:
        reference_model, _ = load(args.reference_model_path)
    else:
        reference_model, _ = load(args.model)

    train_grpo(
        model=model,
        ref_model=reference_model.freeze(),
        tokenizer=tokenizer,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        args=training_args,
        training_callback=training_callback,
    )
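
# Example (illustrative): the reward_weights string is parsed with a simple
# strip/split, so `--reward-weights "[1.0, 0.5]"` becomes [1.0, 0.5]; the list
# length must match the number of reward functions used by the GRPO trainer.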


def train_model(
    args,
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    train_set,
    valid_set,
    training_callback: TrainingCallback = None,
):
    model.freeze()
    if args.num_layers > len(model.layers):
        raise ValueError(
            f"Requested to train {args.num_layers} layers "
            f"but the model only has {len(model.layers)} layers."
        )

    if args.fine_tune_type == "full":
        for l in model.layers[-max(args.num_layers, 0) :]:
            l.unfreeze()
    elif args.fine_tune_type in ["lora", "dora"]:
        # Convert linear layers to lora/dora layers and unfreeze in the process
        linear_to_lora_layers(
            model,
            args.num_layers,
            args.lora_parameters,
            use_dora=(args.fine_tune_type == "dora"),
        )
    else:
        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

    # Resume from weights if provided
    if args.resume_adapter_file is not None:
        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    print_trainable_parameters(model)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)

    adapter_file = adapter_path / "adapters.safetensors"
    save_config(vars(args), adapter_path / "adapter_config.json")

    model.train()
    opt = optim.Adam(
        learning_rate=(
            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
        )
    )

    # Train model
    if args.training_mode == "grpo":
        train_model_grpo(
            model,
            tokenizer,
            args,
            opt,
            train_set,
            valid_set,
            adapter_file,
            training_callback,
        )
    else:
        training_args = TrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            grad_checkpoint=args.grad_checkpoint,
        )

        train(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            training_callback=training_callback,
        )
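
# Illustrative lr_schedule block for a YAML config (the accepted schema is
# whatever build_schedule in tuner.utils parses; the keys below follow the
# mlx-lm LoRA documentation and should be treated as an assumption):
#
#   lr_schedule:
#     name: cosine_decay
#     warmup: 100                     # linear warmup steps
#     warmup_init: 1e-7               # initial learning rate for warmup
#     arguments: [1e-5, 1000, 1e-7]   # positional args for the named scheduler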


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
    model.eval()

    if args.training_mode == "grpo":
        if args.reference_model_path:
            reference_model, _ = load(args.reference_model_path)
        else:
            reference_model, _ = load(args.model)

        test_loss, _, test_rewards = evaluate_grpo(
            model=model,
            ref_model=reference_model.freeze(),
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            temperature=args.temperature,
            max_tokens=args.max_seq_length,
        )

        test_ppl = math.exp(test_loss)

        rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
        print(
            f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, "
            f"Rewards: {rewards_str}"
        )
    else:
        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
        )

        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    if args.test and not args.train:
        # Allow testing without LoRA layers by providing empty path
        if args.adapter_path != "":
            load_adapters(model, args.adapter_path)
    elif args.train:
        print("Training")
        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
    else:
        raise ValueError("Must provide at least one of --train or --test")

    if args.test:
        print("Testing")
        evaluate_model(args, model, tokenizer, test_set)


def main():
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    parser = build_parser()
    args = parser.parse_args()
    config = args.config
    args = vars(args)
    if config:
        print("Loading configuration file", config)
        with open(config, "r") as file:
            config = yaml.load(file, yaml_loader)
        # Prefer parameters from command-line arguments
        for k, v in config.items():
            if args.get(k, None) is None:
                args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if args.get(k, None) is None:
            args[k] = v
    run(types.SimpleNamespace(**args))


if __name__ == "__main__":
    main()
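
# Precedence example (per main() above): a value given on the command line wins
# over the YAML config, and the YAML config wins over CONFIG_DEFAULTS. E.g. with
# `--batch-size 8`, `batch_size: 2` in the config, and a default of 4, training
# runs with batch size 8; a key set only in the config overrides the default.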


def compute_grpo_loss_and_grad(
    model,
    ref_model,
    tokenizer,
    completion_tensors,
    prompt_texts,
    answer_texts,
    beta=0.1,
    epsilon=1e-4,
    reward_funcs=None,
    reward_weights=None,
):
    """
    Compute GRPO loss and gradients using pre-generated completions.

    Assumes `get_per_token_logps` (the per-token log-probability helper from the
    GRPO trainer module) is available in scope.

    Args:
        model: The policy model
        ref_model: The reference model
        tokenizer: Tokenizer used to decode completions for reward computation
        completion_tensors: List of tensors containing generated completions
        prompt_texts: List of prompt texts
        answer_texts: List of answer texts
        beta: KL penalty coefficient
        epsilon: Numerical stability constant
        reward_funcs: List of reward functions
        reward_weights: Optional weights for reward functions
    """
    # Ensure model is in training mode for gradient computation
    model.train()

    # Get completion texts for reward calculation
    completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]

    # Prepare inputs for loss computation
    max_length = max(tensor.shape[0] for tensor in completion_tensors)
    padded_completions = []
    attention_masks = []

    for completion_tensor in completion_tensors:
        padding_length = max_length - completion_tensor.shape[0]
        if padding_length > 0:
            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
            padded_ids = mx.concatenate([completion_tensor, padding])
            mask = mx.concatenate(
                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
            )
        else:
            padded_ids = completion_tensor
            mask = mx.ones_like(completion_tensor)
        padded_completions.append(padded_ids)
        attention_masks.append(mask)

    inputs = mx.stack(padded_completions)
    attention_mask = mx.stack(attention_masks)
    lengths = attention_mask.sum(axis=1)

    # Compute log probabilities for both models
    token_log_probs = get_per_token_logps(model, inputs, lengths)

    if ref_model is None:
        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
    else:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]

    # Pad log probabilities to same length
    max_len = max(x.shape[0] for x in token_log_probs)
    padded_log_probs = []
    padded_ref_log_probs = []

    for i in range(len(token_log_probs)):
        seq_len = token_log_probs[i].shape[0]
        padding = mx.zeros((max_len - seq_len,))

        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

    token_log_probs = mx.stack(padded_log_probs)
    ref_token_log_probs = mx.stack(padded_ref_log_probs)

    # Calculate rewards
    all_func_rewards = []
    for reward_func in reward_funcs:
        func_rewards = mx.array(
            reward_func(
                prompts=prompt_texts,
                completions=completion_texts,
                answer=answer_texts,
            )
        )
        all_func_rewards.append(func_rewards)

    # Stack rewards and apply weights
    rewards = mx.stack(all_func_rewards, axis=1)
    if reward_weights is not None:
        if len(reward_weights) != len(reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(reward_weights)}) must match number "
                f"of reward functions ({len(reward_funcs)})"
            )
        reward_weights = mx.array(reward_weights, dtype=mx.float32)
    else:
        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)

    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)

    # Group rewards by prompt (assuming completions are grouped by prompt)
    group_size = len(completion_tensors) // len(prompt_texts)
    if len(completion_tensors) % len(prompt_texts) != 0:
        raise ValueError("Number of completions must be divisible by number of prompts")

    rewards_by_group = []
    for i in range(0, len(rewards), group_size):
        rewards_by_group.append(rewards[i : i + group_size])

    # Calculate advantages: normalize each reward within its prompt group
    advantages = mx.zeros_like(rewards)
    for i, group_rewards in enumerate(rewards_by_group):
        if len(group_rewards) > 1:  # Only normalize if we have multiple samples
            mean_reward = mx.mean(group_rewards)
            std_reward = mx.std(group_rewards)

            for j in range(group_size):
                idx = i * group_size + j
                advantages[idx] = (group_rewards[j] - mean_reward) / (
                    std_reward + epsilon
                )
        else:
            # If only one sample, advantage is 0
            advantages[i * group_size] = 0.0

    # Compute KL divergence (k3 estimator: exp(q - p) - (q - p) - 1)
    kl_div = (
        mx.exp(ref_token_log_probs - token_log_probs)
        - (ref_token_log_probs - token_log_probs)
        - 1
    )

    # Create mask for valid tokens
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # Compute policy ratio
    policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)

    # Compute per-token loss
    per_token_loss = -(
        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
    )

    # Average over tokens
    sequence_sums = per_token_loss.sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)
    loss = (sequence_sums / sequence_lengths).mean()

    # Calculate metrics for reporting
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()

    metrics = {
        "total_rewards_mean": mx.mean(rewards),
        "total_rewards_std": mx.std(rewards),
        "kl": mean_kl,
    }

    for i, reward_func in enumerate(reward_funcs):
        func_name = reward_func.__name__
        func_rewards = all_func_rewards[i]
        metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
        metrics[f"{func_name}_std"] = mx.std(func_rewards)

    return loss, sequence_lengths.sum(), metrics
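

# Usage sketch (illustrative): one way to wrap the loss above with MLX's
# value-and-grad transform. `grpo_step` is not part of the training pipeline in
# this file; its name, the batch layout, and the reward_funcs argument are
# assumptions for demonstration only.
def grpo_step(model, ref_model, tokenizer, optimizer, batch, reward_funcs):
    # batch is assumed to be (completion_tensors, prompt_texts, answer_texts),
    # with completions already generated and grouped by prompt.
    completion_tensors, prompt_texts, answer_texts = batch

    def loss_fn(model):
        loss, _, _ = compute_grpo_loss_and_grad(
            model,
            ref_model,
            tokenizer,
            completion_tensors,
            prompt_texts,
            answer_texts,
            reward_funcs=reward_funcs,
        )
        return loss

    # nn.value_and_grad differentiates with respect to the model's trainable
    # parameters only; the frozen reference model receives no gradients.
    loss_and_grad = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad(model)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    return loss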