Goekdeniz-Guelmez
2025-01-19 01:58:29 +01:00
parent 7d279b51ef
commit fa80d081f2
4 changed files with 372 additions and 188 deletions

View File

@@ -12,7 +12,6 @@ import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from ..generate import generate
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs

View File

@@ -1,54 +1,48 @@
# Copyright © 2024 Apple Inc.
import time
from pathlib import Path
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten

from .dpo_trainer import DPOTrainingArgs, grad_checkpoint

@dataclass
class ORPOTrainingArgs(DPOTrainingArgs):
"""
Training arguments specific to ORPO, extending DPO arguments.
"""
    reward_scaling: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for offline rewards."},
    )

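# Note: `beta`, `batch_size`, `iters`, and the reporting/saving intervals used by
# `train_orpo` below are inherited from DPOTrainingArgs / TrainingArgs;
# `reward_scaling` is the only ORPO-specific field added here.
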
def orpo_loss(
    model,
    chosen: mx.array,
    rejected: mx.array,
    chosen_masks: mx.array,
    rejected_masks: mx.array,
    chosen_rewards: mx.array,
    rejected_rewards: mx.array,
    beta: float,
    reward_scaling: float = 1.0,
):
"""
Calculate ORPO loss for inputs.
ORPO extends DPO by adding a KL regularization term to prevent overfitting to preferences.
Calculate ORPO loss using pre-computed rewards.
Args:
model: Policy model
reference_teacher_model: Reference model
chosen: Chosen sequence tokens
rejected: Rejected sequence tokens
chosen_masks: Masks for chosen sequences
rejected_masks: Masks for rejected sequences
chosen_masks: Attention masks for chosen sequences
rejected_masks: Attention masks for rejected sequences
chosen_rewards: Pre-computed rewards for chosen sequences
rejected_rewards: Pre-computed rewards for rejected sequences
beta: Temperature parameter
delta: Margin for DPOP loss type
mu: ORPO hyperparameter for balancing KL divergence (default: 0.5)
loss_type: Loss type ('sigmoid', 'hinge', 'ipo', or 'dpop')
is_reference_free: Whether to use reference-free training
reward_scaling: Scaling factor for rewards
Returns:
Tuple of (loss, reward, num_tokens)
Loss value, rewards, and number of tokens.
"""
def make_predictions(model, x, mask):
inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
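    # make_predictions returns per-token log-probabilities (the negated
    # cross-entropy against the shifted targets), with padded positions
    # zeroed out by the mask.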
    # Calculate log probabilities for policy model
    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)

    # Scale the pre-computed rewards
    chosen_rewards = chosen_rewards * reward_scaling
    rejected_rewards = rejected_rewards * reward_scaling

# ORPO uses the reward difference directly
reward_diff = chosen_rewards - rejected_rewards
# Calculate ORPO loss using logistic function
policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
loss = -nn.log_sigmoid(beta * (policy_diff * reward_diff))
loss = mx.mean(loss)
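    # Informally, the per-pair objective above is
    #   -log(sigmoid(beta * (sum log p_theta(chosen) - sum log p_theta(rejected)) * (r_chosen - r_rejected)))
    # With the binary rewards produced by iterate_orpo_batches (1 for chosen,
    # 0 for rejected), the reward difference reduces to reward_scaling.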
# Calculate number of tokens for logging
num_tokens = (chosen_masks.sum() + rejected_masks.sum())
# Calculate rewards for logging
avg_chosen_reward = mx.mean(chosen_rewards)
avg_rejected_reward = mx.mean(rejected_rewards)
reward = mx.stack([avg_chosen_reward, avg_rejected_reward])
return loss, reward, num_tokens
def evaluate_orpo(
    model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    reward_scaling: float = 1.0,
    max_seq_length=2048,
):
"""
Evaluate model using ORPO metrics.
Args:
model: Policy model to evaluate
reference_model: Reference model for comparison
dataset: Evaluation dataset
tokenizer: Tokenizer for processing text
batch_size: Batch size for evaluation
num_batches: Number of batches to evaluate (-1 for full dataset)
beta: Temperature parameter
delta: Margin for DPOP loss
mu: ORPO KL divergence weight
max_seq_length: Maximum sequence length
loss_type: Type of loss function
is_reference_free: Whether to use reference-free evaluation
Returns:
Tuple of (loss, rewards, kl_metrics), where:
- loss is the total ORPO loss
- rewards is [chosen_reward, rejected_reward]
- kl_metrics is [chosen_kl, rejected_kl]
Evaluation function for ORPO.
"""
    all_losses = 0
    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
    ntokens = 0

index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
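    # int() == 0 never equals the sentinel 1, so iter(int, 1) is an endless
    # iterator: with num_batches == -1 the loop runs until the (single-pass)
    # batch iterator is exhausted.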
for _, batch in zip(
index_iterator,
        iterate_orpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch

        # Compute ORPO loss and rewards
        loss, reward, toks = orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=beta,
            reward_scaling=reward_scaling,
        )
all_losses += loss * toks
all_rewards += reward
ntokens += toks
mx.eval(all_losses, all_rewards, ntokens)
# Aggregate metrics across distributed workers if necessary
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
    ntokens = mx.distributed.all_sum(ntokens)

    return (all_losses / ntokens).item(), all_rewards.tolist()

def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
    Batch iterator for ORPO that, unlike the DPO iterator, also yields
    pre-computed rewards. Expects pre-tokenized 'chosen'/'rejected' sequences.
"""
# Sort pairs by length of the chosen response
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
# Get lengths assuming data is already tokenized
chosen_lengths = [len(x['chosen']) for x in batch]
rejected_lengths = [len(x['rejected']) for x in batch]
max_length = max(max(chosen_lengths), max(rejected_lengths))
if max_length > max_seq_length:
print(
f"[WARNING] Sequences longer than {max_seq_length} tokens "
f"will be truncated."
)
pad_to = 8
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
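            # Pad each batch up to a multiple of 8 (capped at max_seq_length),
            # a common choice for hardware-friendly, more uniform array shapes.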
chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
# Always use binary rewards
chosen_rewards = np.ones((batch_size // step,), np.float32)
rejected_rewards = np.zeros((batch_size // step,), np.float32)
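            # Binary offline rewards: 1.0 for every chosen and 0.0 for every
            # rejected sequence, so reward_diff in orpo_loss equals reward_scaling.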
for j in range(batch_size // step):
# Use pre-tokenized sequences directly
chosen_length = min(chosen_lengths[j], max_seq_length)
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
chosen_masks[j, :chosen_length] = 1.0
rejected_length = min(rejected_lengths[j], max_seq_length)
rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
rejected_masks[j, :rejected_length] = 1.0
yield (mx.array(chosen_arr), mx.array(rejected_arr),
mx.array(chosen_masks), mx.array(rejected_masks),
mx.array(chosen_rewards), mx.array(rejected_rewards))
if not train:
break
def train_orpo(
model,
reference_model,
tokenizer,
optimizer,
train_dataset,
val_dataset,
args: ORPOTrainingArgs = ORPOTrainingArgs(),
    training_callback=None,
):
"""
Train a model using ORPO (Offline Rejection Preference Optimization).
This function adapts the DPO training loop to use ORPO loss.
Training function for ORPO.
"""
return train_dpo(
model=model,
reference_model=reference_model,
tokenizer=tokenizer,
optimizer=optimizer,
train_dataset=train_dataset,
val_dataset=val_dataset,
args=args,
loss_fn=orpo_loss,
training_callback=training_callback,
loss_type=args.loss_type,
)
print(f"Starting ORPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
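        # grad_checkpoint trades compute for memory by recomputing activations
        # in the backward pass; as in the shared trainer utilities, patching one
        # layer is assumed to apply to every block of the same class.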
state = [model.state, optimizer.state]
def step(batch):
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
(loss, reward, toks), grad = loss_value_and_grad(
model,
chosen,
rejected,
chosen_masks,
rejected_masks,
chosen_rewards,
rejected_rewards
)
grad = average_gradients(grad)
optimizer.update(model, grad)
return loss, reward, toks
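    # average_gradients all-reduces (averages) gradients across distributed
    # ranks before the optimizer update; with a single process it returns them
    # unchanged.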
def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
chosen_rewards, rejected_rewards):
return orpo_loss(
model=model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
chosen_rewards=chosen_rewards,
rejected_rewards=rejected_rewards,
beta=args.beta,
reward_scaling=args.reward_scaling
)
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
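    # nn.value_and_grad differentiates loss_wrapper w.r.t. the model's trainable
    # parameters only; the returned function yields ((loss, reward, toks), grads).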
# Training loop with progress tracking
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
trained_tokens = 0
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
        iterate_orpo_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
# Evaluate if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_rewards = evaluate_orpo(
model=model,
dataset=val_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
reward_scaling=args.reward_scaling,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val chosen reward {val_rewards[0]:.3f}, "
f"Val rejected reward {val_rewards[1]:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
training_callback.on_val_loss_report({
"iteration": it,
"val_loss": val_loss,
"val_chosen_reward": val_rewards[0],
"val_rejected_reward": val_rewards[1],
"val_time": val_time,
})
start = time.perf_counter()
# Training step
loss, reward, toks = step(batch)
losses += loss
rewards += reward
n_tokens += toks
steps += 1
mx.eval(state, losses, rewards, n_tokens)
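        # MLX is lazy: evaluating the model/optimizer state and the running
        # metrics here materializes each step's computation instead of letting
        # the graph grow across iterations.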
# Report training metrics if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Chosen reward {train_rewards[0]:.3f}, "
f"Rejected reward {train_rewards[1]:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if training_callback is not None:
training_callback.on_train_loss_report({
"iteration": it,
"train_loss": train_loss,
"train_chosen_reward": train_rewards[0],
"train_rejected_reward": train_rewards[1],
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
})
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
start = time.perf_counter()
# Save model weights if needed
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")