# mlx-examples/llms/mlx_lm/tuner/orpo_trainer.py

import time
from pathlib import Path
from dataclasses import dataclass, field
import mlx.nn as nn
import mlx.core as mx
import numpy as np
from mlx.utils import tree_flatten
from mlx.nn.utils import average_gradients

from .trainer import TrainingArgs, grad_checkpoint, TrainingCallback


@dataclass
class ORPOTrainingArgs(TrainingArgs):
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for ORPO training."}
    )


def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
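    # get_logps returns the length-normalized log-probability of each sequence
    # under the current policy, plus an aggregate logit statistic over
    # non-padding positions that is reported only as a diagnostic.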
    def get_logps(model, x, mask):
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
        seq_lengths = mask[:, :-1].sum(-1)
        logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
        logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
        return logp_sum, logits_mean

    policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
    policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
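
    # Log odds ratio between the (length-normalized) chosen and rejected
    # likelihoods: log[p_c / (1 - p_c)] - log[p_r / (1 - p_r)], with
    # log(1 - p) computed stably from log p via log1p(-exp(log p)).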
    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
    )
    ratio = nn.log_sigmoid(log_odds)
    loss = -beta * ratio

    accuracies = (log_odds > 0).astype(mx.float32)
    margins = mx.mean(ratio - 1)
    metrics = {
        'accuracies': mx.mean(accuracies),
        'margins': margins,
        'policy_rejected_logps': mx.mean(policy_rejected_logps),
        'policy_chosen_logps': mx.mean(policy_chosen_logps),
        'rejected_logits_mean': mx.mean(rejected_logits_mean),
        'chosen_logits_mean': mx.mean(chosen_logits_mean)
    }
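
    # Implicit rewards (beta-scaled sequence log-probabilities) and the total
    # token count, both returned so callers can log and weight the results.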
    chosen_reward = beta * policy_chosen_logps
    rejected_reward = beta * policy_rejected_logps
    reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
    num_tokens = chosen_masks.sum() + rejected_masks.sum()

    return mx.mean(loss), reward, num_tokens, metrics

def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
2025-01-19 20:45:33 +08:00
Modified batch iterator for ORPO that includes preference scores.
2025-01-19 08:58:29 +08:00
Works with pre-tokenized input data.
"""
# Sort pairs by length of the chosen response
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size}"
            f" examples but only has {len(dataset)}."
        )
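
    # When running distributed, each worker builds batches of
    # batch_size // step examples by taking every `step`-th index from each
    # window of length-sorted examples.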
    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]

            # Get lengths assuming data is already tokenized
            chosen_lengths = [len(x['chosen']) for x in batch]
            rejected_lengths = [len(x['rejected']) for x in batch]
            max_length = max(max(chosen_lengths), max(rejected_lengths))

            if max_length > max_seq_length:
                print(
                    f"[WARNING] Sequences longer than {max_seq_length} tokens "
                    f"will be truncated."
                )
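
            # Pad the batch width up to a multiple of 8 tokens, capped at
            # max_seq_length.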
            pad_to = 8
            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
            max_length_in_batch = min(max_length_in_batch, max_seq_length)

            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            # Get preference scores and convert to rewards
            preference_scores = [x.get('preference_score', 1.0) for x in batch]
            chosen_rewards = np.array(preference_scores, np.float32)
            rejected_rewards = np.array([1.0 - score for score in preference_scores], np.float32)
            for j in range(batch_size // step):
                # Use pre-tokenized sequences directly
                chosen_length = min(chosen_lengths[j], max_seq_length)
                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                chosen_masks[j, :chosen_length] = 1.0

                rejected_length = min(rejected_lengths[j], max_seq_length)
                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                rejected_masks[j, :rejected_length] = 1.0
            yield (
                mx.array(chosen_arr),
                mx.array(rejected_arr),
                mx.array(chosen_masks),
                mx.array(rejected_masks),
                mx.array(chosen_rewards),
                mx.array(rejected_rewards)
            )

        if not train:
            break

def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
    all_losses = 0
    all_rewards = mx.zeros((2,))
    all_metrics = None
    ntokens = 0

    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
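
    # `iter(int, 1)` repeatedly calls int() (always 0, never the sentinel 1),
    # giving an endless iterator when num_batches == -1.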
    for _, batch in zip(
        index_iterator,
        iterate_orpo_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
        loss, reward, toks, metrics = orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=beta
        )

        all_losses += loss * toks
        all_rewards += reward * toks
        ntokens += toks
        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        mx.eval(all_losses, all_rewards, ntokens)
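
    # Reduce the token-weighted sums across workers, then normalize by the
    # global token count.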
    all_losses = mx.distributed.all_sum(all_losses)
    all_rewards = mx.distributed.all_sum(all_rewards)
    ntokens = mx.distributed.all_sum(ntokens)
    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_rewards = (all_rewards / ntokens).tolist()
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, avg_rewards, ntokens, avg_metrics

def train_orpo(
    model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: ORPOTrainingArgs = ORPOTrainingArgs(),
    training_callback: TrainingCallback = None,
):
    print(f"Starting ORPO training..., iters: {args.iters}")

    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]
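
    # One optimization step: compute loss and gradients, average the gradients
    # across workers, then apply the optimizer update.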
    def step(batch):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
        (loss, reward, toks, metrics), grad = loss_value_and_grad(
            model,
            chosen,
            rejected,
            chosen_masks,
            rejected_masks,
            chosen_rewards,
            rejected_rewards
        )
        grad = average_gradients(grad)
        optimizer.update(model, grad)
        return loss, reward, toks, metrics
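
    # nn.value_and_grad differentiates the wrapper with respect to the model's
    # trainable parameters; `beta` is taken from the training arguments.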
    def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
                     chosen_rewards, rejected_rewards):
        return orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=args.beta
        )

    loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
    # Training loop with progress tracking
    losses = 0
    rewards = mx.zeros((2,))
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = {
        'accuracies': 0,
        'margins': 0,
        'policy_rejected_logps': 0,
        'policy_chosen_logps': 0,
        'rejected_logits_mean': 0,
        'chosen_logits_mean': 0
    }

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_orpo_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
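        # Validate at the first iteration, every `steps_per_eval` iterations,
        # and at the final iteration.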
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
                model=model,
                dataset=val_dataset,
                tokenizer=tokenizer,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                beta=args.beta
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                print(
                    f"Iter {it}: "
                    f"Val loss {val_loss:.8f}, "
                    f"Val chosen reward {val_rewards[0]:.3f}, "
                    f"Val rejected reward {val_rewards[1]:.3f}, "
                    f"Val accuracy {val_metrics['accuracies']:.3f}, "
                    f"Val margin {val_metrics['margins']:.3f}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    "val_chosen_reward": val_rewards[0],
                    "val_rejected_reward": val_rewards[1],
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            start = time.perf_counter()
        # Training step
        loss, reward, toks, metrics = step(batch)
        losses += loss
        rewards += reward
        n_tokens += toks
        steps += 1
        for k, v in metrics.items():
            accumulated_metrics[k] += v

        mx.eval(state, losses, rewards, n_tokens)
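
        # Report averaged training statistics every `steps_per_report`
        # iterations and at the final iteration.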
        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
            train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
            n_tokens = mx.distributed.all_sum(n_tokens).item()

            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9
            if rank == 0:
                print(
                    f"Iter {it}: Train loss {train_loss:.8f}, "
                    f"Chosen reward {train_rewards[0]:.3f}, "
                    f"Rejected reward {train_rewards[1]:.3f}, "
                    f"Accuracy {avg_metrics['accuracies']:.3f}, "
                    f"Margin {avg_metrics['margins']:.3f}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    "train_chosen_reward": train_rewards[0],
                    "train_rejected_reward": train_rewards[1],
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })
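
            # Reset accumulators for the next reporting window.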
            losses = 0
            rewards = mx.zeros((2,))
            n_tokens = 0
            steps = 0
            accumulated_metrics = {k: 0 for k in accumulated_metrics}
            start = time.perf_counter()
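
        # Periodically save the adapter weights, both to the main adapter file
        # and to a numbered checkpoint.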
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")