Remove DPO and fix several ORPO issues

Goekdeniz-Guelmez 2025-01-24 16:09:22 +01:00
parent 0bb001121e
commit e3688293ed
4 changed files with 153 additions and 714 deletions

View File

@@ -15,7 +15,6 @@ import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
from .tuner.utils import (
build_schedule,
@@ -176,7 +175,7 @@ def build_parser():
default=None,
)
parser.add_argument("--beta", type=float)
parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpo"])
parser.add_argument("--is-reference-free", action="store_true")
parser.add_argument("--delta", type=float)
parser.add_argument("--reference-model-path", type=str)
@@ -229,40 +228,7 @@ def train_model(
)
# Train model based on training mode
if args.training_mode == "dpo":
training_args = DPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
loss_type=args.dpo_loss_type,
is_reference_free=args.is_reference_free,
delta=args.delta,
reference_model_path=args.reference_model_path,
)
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)
train_dpo(
model=model,
reference_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
elif args.training_mode == "orpo":
if args.training_mode == "orpo":
training_args = ORPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
@@ -273,8 +239,7 @@ def train_model(
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
reward_scaling=args.reward_scaling,
beta=args.beta
)
train_orpo(
@@ -284,7 +249,7 @@ def train_model(
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
training_callback=training_callback
)
else:
training_args = TrainingArgs(
@@ -313,26 +278,7 @@ def train_model(
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
if args.training_mode == "dpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model = model
test_loss, test_rewards = evaluate_dpo(
model=model,
reference_model=reference_model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
delta=args.delta,
loss_type=args.dpo_loss_type,
)
print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
elif args.training_mode == "orpo":
if args.training_mode == "orpo":
test_loss, test_rewards = evaluate_orpo(
model=model,
dataset=test_set,
@@ -340,8 +286,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
reward_scaling=args.reward_scaling,
beta=args.beta
)
print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:

View File

@@ -4,70 +4,47 @@ from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
class DPODataset:
"""
A dataset for DPO (Direct Preference Optimization) training that handles
prompt-chosen-rejected triplets with optional scores in the format:
{"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
score_chosen_key: str = "score_chosen",
score_rejected_key: str = "score_rejected",
):
class ORPODataset:
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
preference_score_key: str = "preference_score"
):
self._chosen_data = []
self._rejected_data = []
self._scores = []
for d in data:
# Process the text data
chosen_text = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
],
)
rejected_text = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
chosen_text = tokenizer.apply_chat_template([
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
])
rejected_text = tokenizer.apply_chat_template([
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
])
self._chosen_data.append(chosen_text)
self._rejected_data.append(rejected_text)
# Handle scores if they exist
if score_chosen_key in d and score_rejected_key in d:
chosen_score = float(d[score_chosen_key])
rejected_score = float(d[score_rejected_key])
# Normalize scores to [0, 1] range
score_diff = chosen_score - rejected_score
max_diff = max(abs(score_diff), 1.0) # Avoid division by zero
normalized_score = (score_diff / max_diff + 1) / 2
self._scores.append(normalized_score)
if preference_score_key in d:
self._scores.append(float(d[preference_score_key]))
else:
# Default to binary preference (1.0) if no scores provided
self._scores.append(1.0)
def __getitem__(self, idx: int):
return {
"chosen": self._chosen_data[idx],
"rejected": self._rejected_data[idx],
"preference_score": self._scores[idx]
}
def __len__(self):
return len(self._chosen_data)
def __getitem__(self, idx: int):
return {
"chosen": self._chosen_data[idx],
"rejected": self._rejected_data[idx],
"preference_score": self._scores[idx]
}
def __len__(self):
return len(self._chosen_data)
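For orientation, the reshaped ORPODataset above consumes records keyed by "prompt", "chosen", and "rejected", plus an optional "preference_score" that defaults to 1.0 when absent. A minimal sketch of feeding it data follows; the toy records and the tokenizer variable are illustrative and not part of this commit, and the tokenizer is assumed to provide apply_chat_template as the class requires.

# Illustrative only: toy preference records in the shape ORPODataset expects.
# "tokenizer" is assumed to be a chat-template-capable tokenizer (e.g. the
# TokenizerWrapper returned by the loader used elsewhere in this repo).
records = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "I am not sure.",
        "preference_score": 0.9,  # optional; stored as-is
    },
    {
        "prompt": "Name a prime number.",
        "chosen": "2 is a prime number.",
        "rejected": "9 is a prime number.",
        # no preference_score key -> defaults to 1.0
    },
]

dataset = ORPODataset(records, tokenizer)
item = dataset[0]
# item["chosen"] / item["rejected"] are token id sequences produced by
# apply_chat_template; item["preference_score"] is 0.9 for the first record.
print(len(dataset), item["preference_score"])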
class Dataset:
@@ -158,7 +135,7 @@ def create_dataset(
# Add DPO dataset support
if "chosen" in sample and "rejected" in sample:
return DPODataset(data, tokenizer)
return ORPODataset(data, tokenizer)
elif "messages" in sample:
return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:

View File

@@ -1,457 +0,0 @@
# Copyright © 2024 Apple Inc.
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
@dataclass
class DPOTrainingArgs(TrainingArgs):
beta: float = field(
default=0.1,
metadata={"help": "Temperature parameter for DPO training."}
)
loss_type: str = field(
default="sigmoid",
metadata={
"help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
}
)
is_reference_free: bool = field(
default=False,
metadata={
"help": "Whether to use reference-free DPO training."
}
)
delta: float = field(
default=50.0,
metadata={
"help": "Delta parameter for DPOP loss type."
}
)
reference_model_path: str = field(
default=None,
metadata={
"help": "Path to reference model weights. If None, uses the same model."
}
)
seed: int = field(
default=42,
metadata={
"help": "Random seed for reproducibility."
}
)
def dpo_loss(
model,
reference_teacher_model,
chosen: mx.array,
rejected: mx.array,
chosen_masks: mx.array,
rejected_masks: mx.array,
beta: float,
delta: float,
loss_type: str = "sigmoid",
is_reference_free: bool = False
):
"""
Calculate loss for inputs.
Args:
inputs: Input tokens.
targets: Target tokens.
lengths: Lengths of inputs.
Returns:
Loss value.
"""
def make_predictions(model, x, mask):
inputs = x[:, :-1]
targets = x[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
num_chosen_tokens = chosen_masks.sum(-1)
num_rejected_tokens = rejected_masks.sum(-1)
# Calculate log probabilities for policy model
policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
if loss_type == "ipo":
# ipo uses average log probabilities
policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens
policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens
else:
policy_chosen_score = policy_chosen_scores.sum(-1)
policy_rejected_score = policy_rejected_scores.sum(-1)
# Calculate log probabilities for reference model
if is_reference_free:
reference_chosen_score = mx.zeros_like(policy_chosen_score)
reference_rejected_score = mx.zeros_like(policy_rejected_score)
else:
reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
if loss_type == "ipo":
# ipo uses average log probabilities
reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens
else:
reference_chosen_score = reference_chosen_scores.sum(-1)
reference_rejected_score = reference_rejected_scores.sum(-1)
logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
if loss_type == "sigmoid":
losses = -nn.log_sigmoid(beta * logits)
elif loss_type == "hinge":
losses = nn.relu(1 - beta * logits)
elif loss_type == "ipo":
losses = (logits - 1 / (2 * beta)) ** 2
elif loss_type == "dpop":
delta = 50
penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
else:
raise ValueError(f"Unknown loss type: {loss_type}")
loss = mx.mean(losses)
num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()
chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
reward = mx.stack([chosen_reward, rejected_reward])
return loss, reward, num_tokens
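For the record, the loss variants implemented by the removed dpo_loss can be written compactly. With s_c and s_r the summed (or, for "ipo", length-averaged) policy log-probabilities of the chosen and rejected responses, s_c^ref and s_r^ref the reference-model counterparts (zero when is_reference_free), and z = (s_c - s_r) - (s_c^ref - s_r^ref), the deleted code computed:

\[
\begin{aligned}
\mathcal{L}_{\text{sigmoid}} &= -\log \sigma(\beta z) \\
\mathcal{L}_{\text{hinge}}   &= \max\!\big(0,\; 1 - \beta z\big) \\
\mathcal{L}_{\text{ipo}}     &= \Big(z - \tfrac{1}{2\beta}\Big)^{2} \\
\mathcal{L}_{\text{dpop}}    &= -\Big(\log \sigma(\beta z) - \delta \,\max\!\big(0,\; s_c^{\mathrm{ref}} - s_c\big)\Big)
\end{aligned}
\]

with the logged rewards defined as \beta\,(s_c - s_c^{\mathrm{ref}}) for the chosen response and \beta\,(s_r - s_r^{\mathrm{ref}}) for the rejected one.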
def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Modified iterate_batches for DPO training that handles chosen and rejected samples.
"""
# Sort pairs by length of the chosen response
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
# Get lengths for chosen and rejected sequences
chosen_lengths = [len(x['chosen']) for x in batch]
rejected_lengths = [len(x['rejected']) for x in batch]
max_length = max(max(chosen_lengths), max(rejected_lengths))
if max_length > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sequence {max_length} will be truncated to {max_seq_length}."
)
# Pad to nearest multiple of 8
pad_to = 8
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
# Create arrays for chosen and rejected sequences
chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
# Create attention masks
chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
for j in range(batch_size // step):
# Process chosen sequence
chosen_length = min(chosen_lengths[j], max_seq_length)
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
chosen_masks[j, :chosen_length] = 1.0
# Process rejected sequence
rejected_length = min(rejected_lengths[j], max_seq_length)
rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
rejected_masks[j, :rejected_length] = 1.0
yield (mx.array(chosen_arr), mx.array(rejected_arr),
mx.array(chosen_masks), mx.array(rejected_masks))
if not train:
break
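The pad-to-a-multiple-of-8 arithmetic used in the iterator above is the usual ceiling trick; a quick illustrative check (plain Python, not from the diff):

pad_to = 8
for max_length in (13, 16, 17):
    padded = pad_to * ((max_length + pad_to - 1) // pad_to)
    print(max_length, "->", padded)
# prints: 13 -> 16, 16 -> 16, 17 -> 24; the result is then clamped to max_seq_length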
def evaluate_dpo(
model,
reference_model,
dataset,
tokenizer,
batch_size,
num_batches,
beta: float,
delta: float,
max_seq_length=2048,
loss_fn: callable = dpo_loss,
loss_type="sigmoid",
):
"""
Modified evaluate function for DPO training.
"""
all_losses = 0
all_rewards = mx.zeros((2,)) # [chosen_reward, rejected_reward]
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_dpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
chosen, rejected, chosen_masks, rejected_masks = batch
loss, reward, toks = loss_fn(
model=model,
reference_teacher_model=reference_model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
loss_type=loss_type,
beta=beta,
delta=delta,
)
all_losses += loss * toks
all_rewards += reward
ntokens += toks
mx.eval(all_losses, all_rewards, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item(), all_rewards.tolist()
def train_dpo(
model,
reference_model,
tokenizer,
optimizer,
train_dataset,
val_dataset,
args: DPOTrainingArgs = DPOTrainingArgs(),
loss_fn: callable = dpo_loss,
training_callback: TrainingCallback = None,
loss_type="sigmoid",
):
"""
Modified training function for DPO.
"""
print(f"Starting DPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
def step(batch):
chosen, rejected, chosen_masks, rejected_masks = batch
# Remove loss_type from the call
(loss, reward, toks), grad = loss_value_and_grad(
model,
reference_model,
chosen,
rejected,
chosen_masks,
rejected_masks
)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)
return loss, reward, toks
# Create a wrapper function that includes all required arguments
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
return loss_fn(
model=model,
reference_teacher_model=ref_model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
beta=args.beta,
delta=args.delta,
loss_type=loss_type,
is_reference_free=args.is_reference_free
)
# Create value_and_grad with the wrapper
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
trained_tokens = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_dpo_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
# Report validation loss if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_rewards = evaluate_dpo(
model=model,
reference_model=reference_model,
dataset=val_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
loss_fn=loss_fn,
beta=args.beta,
delta=args.delta,
loss_type=loss_type,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val chosen reward {val_rewards[0]:.3f}, "
f"Val rejected reward {val_rewards[1]:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_chosen_reward": val_rewards[0],
"val_rejected_reward": val_rewards[1],
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
loss, reward, toks = step(batch)
losses += loss
rewards += reward
n_tokens += toks
steps += 1
mx.eval(state, losses, rewards, n_tokens)
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item()
train_loss /= steps * world_size
train_rewards = mx.distributed.all_sum(rewards).tolist()
train_rewards = [r / (steps * world_size) for r in train_rewards]
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Chosen reward {train_rewards[0]:.3f}, "
f"Rejected reward {train_rewards[1]:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if training_callback is not None:
train_info = {
"iteration": it,
"train_loss": train_loss,
"train_chosen_reward": train_rewards[0],
"train_rejected_reward": train_rewards[1],
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
}
training_callback.on_train_loss_report(train_info)
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")

View File

@@ -14,128 +14,48 @@ from .trainer import TrainingArgs, grad_checkpoint, TrainingCallback
class ORPOTrainingArgs(TrainingArgs):
beta: float = field(
default=0.1,
metadata={"help": "Temperature parameter for DPO training."}
)
reward_scaling: float = field(
default=1.0,
metadata={"help": "Scaling factor for offline rewards."}
metadata={"help": "Temperature parameter for ORPO training."}
)
def orpo_loss(
model,
chosen: mx.array,
rejected: mx.array,
chosen_masks: mx.array,
rejected_masks: mx.array,
chosen_rewards: mx.array,
rejected_rewards: mx.array,
beta: float = 0.1,
reward_scaling: float = 1.0,
):
"""
Calculate ORPO loss using pre-computed rewards that incorporate preference scores.
Args:
model: Policy model
chosen: Chosen sequence tokens
rejected: Rejected sequence tokens
chosen_masks: Attention masks for chosen sequences
rejected_masks: Attention masks for rejected sequences
chosen_rewards: Rewards for chosen sequences (derived from preference scores)
rejected_rewards: Rewards for rejected sequences (derived from preference scores)
beta: Temperature parameter
reward_scaling: Scaling factor for rewards
Returns:
Loss value, rewards, and number of tokens.
"""
def make_predictions(model, x, mask):
inputs = x[:, :-1]
targets = x[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
def get_logps(model, x, mask):
inputs = x[:, :-1]
targets = x[:, 1:]
logits = model(inputs)
logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
seq_lengths = mask[:, :-1].sum(-1)
logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
return logp_sum, logits_mean
# Calculate log probabilities for policy model
policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
# Scale the pre-computed rewards
chosen_rewards = chosen_rewards * reward_scaling
rejected_rewards = rejected_rewards * reward_scaling
# Calculate reward difference
reward_diff = chosen_rewards - rejected_rewards
# Calculate ORPO loss using logistic function
policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
loss = -nn.log_sigmoid(beta * (policy_diff * reward_diff))
loss = mx.mean(loss)
# Calculate number of tokens for logging
num_tokens = (chosen_masks.sum() + rejected_masks.sum())
# Calculate rewards for logging
avg_chosen_reward = mx.mean(chosen_rewards)
avg_rejected_reward = mx.mean(rejected_rewards)
reward = mx.stack([avg_chosen_reward, avg_rejected_reward])
return loss, reward, num_tokens
def evaluate_orpo(
model,
dataset,
tokenizer,
batch_size,
num_batches,
beta: float,
reward_scaling: float = 1.0,
max_seq_length=2048,
):
"""
Evaluation function for ORPO.
"""
all_losses = 0
all_rewards = mx.zeros((2,))
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_orpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
loss, reward, toks = orpo_loss(
model=model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
chosen_rewards=chosen_rewards,
rejected_rewards=rejected_rewards,
beta=beta,
reward_scaling=reward_scaling,
)
all_losses += loss * toks
all_rewards += reward
ntokens += toks
mx.eval(all_losses, all_rewards, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item(), all_rewards.tolist()
policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
)
ratio = nn.log_sigmoid(log_odds)
loss = -beta * ratio
accuracies = (log_odds > 0).astype(mx.float32)
margins = mx.mean(ratio)
metrics = {
'accuracies': mx.mean(accuracies),
'margins': margins,
'policy_rejected_logps': mx.mean(policy_rejected_logps),
'policy_chosen_logps': mx.mean(policy_chosen_logps),
'rejected_logits_mean': mx.mean(rejected_logits_mean),
'chosen_logits_mean': mx.mean(chosen_logits_mean)
}
chosen_reward = beta * policy_chosen_logps
rejected_reward = beta * policy_rejected_logps
reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
num_tokens = chosen_masks.sum() + rejected_masks.sum()
return mx.mean(loss), reward, num_tokens, metrics
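For reference, the rewritten orpo_loss above computes the odds-ratio term of ORPO. Writing p_c and p_r for the length-averaged log-probabilities returned by get_logps for the chosen and rejected responses, the quantity log_odds and the loss are

\[
\log \mathrm{odds}(y \mid x) = \log \frac{P(y \mid x)}{1 - P(y \mid x)},
\qquad
\text{log\_odds} = \big(p_c - p_r\big) - \Big(\log\big(1 - e^{p_c}\big) - \log\big(1 - e^{p_r}\big)\Big),
\qquad
\mathcal{L} = -\beta \,\log \sigma(\text{log\_odds}),
\]

where log(1 - e^{p}) is evaluated stably as log1p(-exp(p)). The logged rewards are beta * p_c and beta * p_r, and the accuracies metric counts how often the chosen response has higher odds than the rejected one. Note that, as written, the chosen_rewards and rejected_rewards arguments do not enter this loss.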
def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
@@ -188,10 +108,6 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
# Get preference scores and convert to rewards
preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
# Convert preference scores to chosen/rejected rewards
# When preference_score is 1.0, chosen_reward=1.0, rejected_reward=0.0
# When preference_score is 0.0, chosen_reward=0.0, rejected_reward=1.0
# When preference_score is 0.5, both rewards are 0.5
chosen_rewards = preference_scores
rejected_rewards = 1.0 - preference_scores
@@ -218,6 +134,56 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
break
def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
all_losses = 0
all_rewards = mx.zeros((2,))
all_metrics = None
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_orpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
loss, reward, toks, metrics = orpo_loss(
model=model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
chosen_rewards=chosen_rewards,
rejected_rewards=rejected_rewards,
beta=beta
)
all_losses += loss * toks
all_rewards += reward * toks
ntokens += toks
if all_metrics is None:
all_metrics = {k: v * toks for k, v in metrics.items()}
else:
for k, v in metrics.items():
all_metrics[k] += v * toks
mx.eval(all_losses, all_rewards, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_rewards = (all_rewards / ntokens).tolist()
avg_loss = (all_losses / ntokens).item()
return avg_loss, avg_rewards, ntokens, avg_metrics
def train_orpo(
model,
tokenizer,
@@ -227,9 +193,6 @@ def train_orpo(
args: ORPOTrainingArgs = ORPOTrainingArgs(),
training_callback: TrainingCallback = None,
):
"""
Training function for ORPO.
"""
print(f"Starting ORPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
@@ -246,7 +209,7 @@ def train_orpo(
def step(batch):
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
(loss, reward, toks), grad = loss_value_and_grad(
(loss, reward, toks, metrics), grad = loss_value_and_grad(
model,
chosen,
rejected,
@@ -259,7 +222,7 @@ def train_orpo(
grad = average_gradients(grad)
optimizer.update(model, grad)
return loss, reward, toks
return loss, reward, toks, metrics
def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
chosen_rewards, rejected_rewards):
@@ -271,8 +234,7 @@ def train_orpo(
rejected_masks=rejected_masks,
chosen_rewards=chosen_rewards,
rejected_rewards=rejected_rewards,
beta=args.beta,
reward_scaling=args.reward_scaling
beta=args.beta
)
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
@@ -283,11 +245,19 @@ def train_orpo(
n_tokens = 0
steps = 0
trained_tokens = 0
accumulated_metrics = {
'accuracies': 0,
'margins': 0,
'policy_rejected_logps': 0,
'policy_chosen_logps': 0,
'rejected_logits_mean': 0,
'chosen_logits_mean': 0
}
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_orpo_batches( # reuse DPO batch iterator
iterate_orpo_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
@@ -295,18 +265,16 @@ def train_orpo(
train=True,
),
):
# Evaluate if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_rewards = evaluate_orpo(
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
model=model,
dataset=val_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
reward_scaling=args.reward_scaling,
beta=args.beta
)
val_time = time.perf_counter() - stop
if rank == 0:
@@ -315,6 +283,8 @@ def train_orpo(
f"Val loss {val_loss:.8f}, "
f"Val chosen reward {val_rewards[0]:.3f}, "
f"Val rejected reward {val_rewards[1]:.3f}, "
f"Val accuracy {val_metrics['accuracies']:.3f}, "
f"Val margin {val_metrics['margins']:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
@@ -325,25 +295,28 @@ def train_orpo(
"val_loss": val_loss,
"val_chosen_reward": val_rewards[0],
"val_rejected_reward": val_rewards[1],
**{f"val_{k}": v for k, v in val_metrics.items()},
"val_time": val_time,
})
start = time.perf_counter()
# Training step
loss, reward, toks = step(batch)
loss, reward, toks, metrics = step(batch)
losses += loss
rewards += reward
n_tokens += toks
steps += 1
for k, v in metrics.items():
accumulated_metrics[k] += v
mx.eval(state, losses, rewards, n_tokens)
# Report training metrics if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
@@ -356,10 +329,11 @@ def train_orpo(
f"Iter {it}: Train loss {train_loss:.8f}, "
f"Chosen reward {train_rewards[0]:.3f}, "
f"Rejected reward {train_rewards[1]:.3f}, "
f"Accuracy {avg_metrics['accuracies']:.3f}, "
f"Margin {avg_metrics['margins']:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
@@ -370,6 +344,7 @@ def train_orpo(
"train_loss": train_loss,
"train_chosen_reward": train_rewards[0],
"train_rejected_reward": train_rewards[1],
**{f"train_{k}": v for k, v in avg_metrics.items()},
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
@@ -381,9 +356,9 @@ def train_orpo(
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
accumulated_metrics = {k: 0 for k in accumulated_metrics}
start = time.perf_counter()
# Save model weights if needed
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -396,7 +371,6 @@ def train_orpo(
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")