removing dpo and fixing some stuff for orpo

Goekdeniz-Guelmez 2025-01-24 16:09:22 +01:00
parent 0bb001121e
commit e3688293ed
4 changed files with 153 additions and 714 deletions

View File

@@ -15,7 +15,6 @@ import yaml
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
 from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
 from .tuner.utils import (
     build_schedule,
@@ -176,7 +175,7 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
+    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpo"])
    parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
@@ -229,40 +228,7 @@ def train_model(
     )
     # Train model based on training mode
-    if args.training_mode == "dpo":
-        training_args = DPOTrainingArgs(
-            batch_size=args.batch_size,
-            iters=args.iters,
-            val_batches=args.val_batches,
-            steps_per_report=args.steps_per_report,
-            steps_per_eval=args.steps_per_eval,
-            steps_per_save=args.save_every,
-            adapter_file=adapter_file,
-            max_seq_length=args.max_seq_length,
-            grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta,
-            loss_type=args.dpo_loss_type,
-            is_reference_free=args.is_reference_free,
-            delta=args.delta,
-            reference_model_path=args.reference_model_path,
-        )
-        if args.reference_model_path:
-            reference_model, _ = load(args.reference_model_path)
-        else:
-            reference_model, _ = load(args.model)
-        train_dpo(
-            model=model,
-            reference_model=reference_model.freeze(),
-            tokenizer=tokenizer,
-            optimizer=opt,
-            train_dataset=train_set,
-            val_dataset=valid_set,
-            args=training_args,
-            training_callback=training_callback,
-        )
-    elif args.training_mode == "orpo":
+    if args.training_mode == "orpo":
         training_args = ORPOTrainingArgs(
             batch_size=args.batch_size,
             iters=args.iters,
@@ -273,8 +239,7 @@ def train_model(
             adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta,
-            reward_scaling=args.reward_scaling,
+            beta=args.beta
         )
         train_orpo(
@@ -284,7 +249,7 @@ def train_model(
             train_dataset=train_set,
             val_dataset=valid_set,
             args=training_args,
-            training_callback=training_callback,
+            training_callback=training_callback
         )
     else:
         training_args = TrainingArgs(
@@ -313,26 +278,7 @@ def train_model(
 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()
-    if args.training_mode == "dpo":
-        if args.reference_model_path:
-            reference_model, _ = load(args.reference_model_path)
-        else:
-            reference_model = model
-        test_loss, test_rewards = evaluate_dpo(
-            model=model,
-            reference_model=reference_model,
-            dataset=test_set,
-            tokenizer=tokenizer,
-            batch_size=args.batch_size,
-            num_batches=args.test_batches,
-            max_seq_length=args.max_seq_length,
-            beta=args.beta,
-            delta=args.delta,
-            loss_type=args.dpo_loss_type,
-        )
-        print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
-    elif args.training_mode == "orpo":
+    if args.training_mode == "orpo":
         test_loss, test_rewards = evaluate_orpo(
             model=model,
             dataset=test_set,
@@ -340,8 +286,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
             batch_size=args.batch_size,
             num_batches=args.test_batches,
             max_seq_length=args.max_seq_length,
-            beta=args.beta,
-            reward_scaling=args.reward_scaling,
+            beta=args.beta
         )
         print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
     else:

View File

@@ -4,70 +4,47 @@ from typing import Dict, List, Optional
 from transformers import PreTrainedTokenizer

-class DPODataset:
-    """
-    A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets with optional scores in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
-    """
-
+class ORPODataset:
     def __init__(
         self,
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         prompt_key: str = "prompt",
         chosen_key: str = "chosen",
         rejected_key: str = "rejected",
-        score_chosen_key: str = "score_chosen",
-        score_rejected_key: str = "score_rejected",
+        preference_score_key: str = "preference_score"
     ):
         self._chosen_data = []
         self._rejected_data = []
         self._scores = []
         for d in data:
-            # Process the text data
-            chosen_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            rejected_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
+            chosen_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[chosen_key]},
+            ])
+            rejected_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[rejected_key]},
+            ])
             self._chosen_data.append(chosen_text)
             self._rejected_data.append(rejected_text)
-
-            # Handle scores if they exist
-            if score_chosen_key in d and score_rejected_key in d:
-                chosen_score = float(d[score_chosen_key])
-                rejected_score = float(d[score_rejected_key])
-                # Normalize scores to [0, 1] range
-                score_diff = chosen_score - rejected_score
-                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
-                normalized_score = (score_diff / max_diff + 1) / 2
-                self._scores.append(normalized_score)
+            if preference_score_key in d:
+                self._scores.append(float(d[preference_score_key]))
             else:
-                # Default to binary preference (1.0) if no scores provided
                 self._scores.append(1.0)

     def __getitem__(self, idx: int):
         return {
             "chosen": self._chosen_data[idx],
             "rejected": self._rejected_data[idx],
             "preference_score": self._scores[idx]
         }

     def __len__(self):
         return len(self._chosen_data)

 class Dataset:
@@ -158,7 +135,7 @@ def create_dataset(
     # Add DPO dataset support
     if "chosen" in sample and "rejected" in sample:
-        return DPODataset(data, tokenizer)
+        return ORPODataset(data, tokenizer)
     elif "messages" in sample:
         return ChatDataset(data, tokenizer)
     elif prompt_feature in sample and completion_feature in sample:
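
The new ORPODataset keeps the old behaviour of building chat-templated chosen and rejected sequences, but reads a single optional preference_score field instead of the score_chosen / score_rejected pair. Assuming the JSONL layout that mlx_lm's dataset loader typically reads (an assumption, not something this diff shows), a compatible record set could look like the following; the field names come from the constructor defaults above, the content is purely illustrative:

{"prompt": "What is the capital of France?", "chosen": "The capital of France is Paris.", "rejected": "No idea.", "preference_score": 1.0}
{"prompt": "Summarize Hamlet in one sentence.", "chosen": "A Danish prince avenges his father's murder and nearly everyone dies.", "rejected": "It's about a castle.", "preference_score": 0.8}

Records without preference_score default to 1.0, i.e. a hard binary preference for the chosen response.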

View File

@@ -1,457 +0,0 @@
# Copyright © 2024 Apple Inc.
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
@dataclass
class DPOTrainingArgs(TrainingArgs):
beta: float = field(
default=0.1,
metadata={"help": "Temperature parameter for DPO training."}
)
loss_type: str = field(
default="sigmoid",
metadata={
"help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
}
)
is_reference_free: bool = field(
default=False,
metadata={
"help": "Whether to use reference-free DPO training."
}
)
delta: float = field(
default=50.0,
metadata={
"help": "Delta parameter for DPOP loss type."
}
)
reference_model_path: str = field(
default=None,
metadata={
"help": "Path to reference model weights. If None, uses the same model."
}
)
seed: int = field(
default=42,
metadata={
"help": "Random seed for reproducibility."
}
)
def dpo_loss(
model,
reference_teacher_model,
chosen: mx.array,
rejected: mx.array,
chosen_masks: mx.array,
rejected_masks: mx.array,
beta: float,
delta: float,
loss_type: str = "sigmoid",
is_reference_free: bool = False
):
"""
Calculate loss for inputs.
Args:
inputs: Input tokens.
targets: Target tokens.
lengths: Lengths of inputs.
Returns:
Loss value.
"""
def make_predictions(model, x, mask):
inputs = x[:, :-1]
targets = x[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
num_chosen_tokens = chosen_masks.sum(-1)
num_rejected_tokens = rejected_masks.sum(-1)
# Calculate log probabilities for policy model
policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
if loss_type == "ipo":
# ipo uses average log probabilities
policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens
policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens
else:
policy_chosen_score = policy_chosen_scores.sum(-1)
policy_rejected_score = policy_rejected_scores.sum(-1)
# Calculate log probabilities for reference model
if is_reference_free:
reference_chosen_score = mx.zeros_like(policy_chosen_score)
reference_rejected_score = mx.zeros_like(policy_rejected_score)
else:
reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
if loss_type == "ipo":
# ipo uses average log probabilities
reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens
else:
reference_chosen_score = reference_chosen_scores.sum(-1)
reference_rejected_score = reference_rejected_scores.sum(-1)
logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
if loss_type == "sigmoid":
losses = -nn.log_sigmoid(beta * logits)
elif loss_type == "hinge":
losses = nn.relu(1 - beta * logits)
elif loss_type == "ipo":
losses = (logits - 1 / (2 * beta)) ** 2
elif loss_type == "dpop":
delta = 50
penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
else:
raise ValueError(f"Unknown loss type: {loss_type}")
loss = mx.mean(losses)
num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()
chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
reward = mx.stack([chosen_reward, rejected_reward])
return loss, reward, num_tokens
def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Modified iterate_batches for DPO training that handles chosen and rejected samples.
"""
# Sort pairs by length of the chosen response
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
# Get lengths for chosen and rejected sequences
chosen_lengths = [len(x['chosen']) for x in batch]
rejected_lengths = [len(x['rejected']) for x in batch]
max_length = max(max(chosen_lengths), max(rejected_lengths))
if max_length > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sequence {max_length} will be truncated to {max_seq_length}."
)
# Pad to nearest multiple of 8
pad_to = 8
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
# Create arrays for chosen and rejected sequences
chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
# Create attention masks
chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
for j in range(batch_size // step):
# Process chosen sequence
chosen_length = min(chosen_lengths[j], max_seq_length)
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
chosen_masks[j, :chosen_length] = 1.0
# Process rejected sequence
rejected_length = min(rejected_lengths[j], max_seq_length)
rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
rejected_masks[j, :rejected_length] = 1.0
yield (mx.array(chosen_arr), mx.array(rejected_arr),
mx.array(chosen_masks), mx.array(rejected_masks))
if not train:
break
def evaluate_dpo(
model,
reference_model,
dataset,
tokenizer,
batch_size,
num_batches,
beta: float,
delta: float,
max_seq_length=2048,
loss_fn: callable = dpo_loss,
loss_type="sigmoid",
):
"""
Modified evaluate function for DPO training.
"""
all_losses = 0
all_rewards = mx.zeros((2,)) # [chosen_reward, rejected_reward]
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_dpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
chosen, rejected, chosen_masks, rejected_masks = batch
loss, reward, toks = loss_fn(
model=model,
reference_teacher_model=reference_model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
loss_type=loss_type,
beta=beta,
delta=delta,
)
all_losses += loss * toks
all_rewards += reward
ntokens += toks
mx.eval(all_losses, all_rewards, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item(), all_rewards.tolist()
def train_dpo(
model,
reference_model,
tokenizer,
optimizer,
train_dataset,
val_dataset,
args: DPOTrainingArgs = DPOTrainingArgs(),
loss_fn: callable = dpo_loss,
training_callback: TrainingCallback = None,
loss_type="sigmoid",
):
"""
Modified training function for DPO.
"""
print(f"Starting DPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
def step(batch):
chosen, rejected, chosen_masks, rejected_masks = batch
# Remove loss_type from the call
(loss, reward, toks), grad = loss_value_and_grad(
model,
reference_model,
chosen,
rejected,
chosen_masks,
rejected_masks
)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)
return loss, reward, toks
# Create a wrapper function that includes all required arguments
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
return loss_fn(
model=model,
reference_teacher_model=ref_model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
beta=args.beta,
delta=args.delta,
loss_type=loss_type,
is_reference_free=args.is_reference_free
)
# Create value_and_grad with the wrapper
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
trained_tokens = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_dpo_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
# Report validation loss if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_rewards = evaluate_dpo(
model=model,
reference_model=reference_model,
dataset=val_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
loss_fn=loss_fn,
beta=args.beta,
delta=args.delta,
loss_type=loss_type,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val chosen reward {val_rewards[0]:.3f}, "
f"Val rejected reward {val_rewards[1]:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_chosen_reward": val_rewards[0],
"val_rejected_reward": val_rewards[1],
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
loss, reward, toks = step(batch)
losses += loss
rewards += reward
n_tokens += toks
steps += 1
mx.eval(state, losses, rewards, n_tokens)
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item()
train_loss /= steps * world_size
train_rewards = mx.distributed.all_sum(rewards).tolist()
train_rewards = [r / (steps * world_size) for r in train_rewards]
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Chosen reward {train_rewards[0]:.3f}, "
f"Rejected reward {train_rewards[1]:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if training_callback is not None:
train_info = {
"iteration": it,
"train_loss": train_loss,
"train_chosen_reward": train_rewards[0],
"train_rejected_reward": train_rewards[1],
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
}
training_callback.on_train_loss_report(train_info)
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")

View File

@@ -14,128 +14,48 @@ from .trainer import TrainingArgs, grad_checkpoint, TrainingCallback
 class ORPOTrainingArgs(TrainingArgs):
     beta: float = field(
         default=0.1,
-        metadata={"help": "Temperature parameter for DPO training."}
-    )
-    reward_scaling: float = field(
-        default=1.0,
-        metadata={"help": "Scaling factor for offline rewards."}
+        metadata={"help": "Temperature parameter for ORPO training."}
     )

-def orpo_loss(
-    model,
-    chosen: mx.array,
-    rejected: mx.array,
-    chosen_masks: mx.array,
-    rejected_masks: mx.array,
-    chosen_rewards: mx.array,
-    rejected_rewards: mx.array,
-    beta: float = 0.1,
-    reward_scaling: float = 1.0,
-):
-    """
-    Calculate ORPO loss using pre-computed rewards that incorporate preference scores.
-    Args:
-        model: Policy model
-        chosen: Chosen sequence tokens
-        rejected: Rejected sequence tokens
-        chosen_masks: Attention masks for chosen sequences
-        rejected_masks: Attention masks for rejected sequences
-        chosen_rewards: Rewards for chosen sequences (derived from preference scores)
-        rejected_rewards: Rewards for rejected sequences (derived from preference scores)
-        beta: Temperature parameter
-        reward_scaling: Scaling factor for rewards
-    Returns:
-        Loss value, rewards, and number of tokens.
-    """
-    def make_predictions(model, x, mask):
-        inputs = x[:, :-1]
-        targets = x[:, 1:]
-        logits = model(inputs)
-        logits = logits.astype(mx.float32)
-        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
-
-    # Calculate log probabilities for policy model
-    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
-    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
-
-    # Scale the pre-computed rewards
-    chosen_rewards = chosen_rewards * reward_scaling
-    rejected_rewards = rejected_rewards * reward_scaling
-
-    # Calculate reward difference
-    reward_diff = chosen_rewards - rejected_rewards
-
-    # Calculate ORPO loss using logistic function
-    policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
-    loss = -nn.log_sigmoid(beta * (policy_diff * reward_diff))
-    loss = mx.mean(loss)
-
-    # Calculate number of tokens for logging
-    num_tokens = (chosen_masks.sum() + rejected_masks.sum())
-
-    # Calculate rewards for logging
-    avg_chosen_reward = mx.mean(chosen_rewards)
-    avg_rejected_reward = mx.mean(rejected_rewards)
-    reward = mx.stack([avg_chosen_reward, avg_rejected_reward])
-
-    return loss, reward, num_tokens
-
-def evaluate_orpo(
-    model,
-    dataset,
-    tokenizer,
-    batch_size,
-    num_batches,
-    beta: float,
-    reward_scaling: float = 1.0,
-    max_seq_length=2048,
-):
-    """
-    Evaluation function for ORPO.
-    """
-    all_losses = 0
-    all_rewards = mx.zeros((2,))
-    ntokens = 0
-    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-    for _, batch in zip(
-        index_iterator,
-        iterate_orpo_batches(
-            dataset=dataset,
-            tokenizer=tokenizer,
-            batch_size=batch_size,
-            max_seq_length=max_seq_length,
-        ),
-    ):
-        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
-        loss, reward, toks = orpo_loss(
-            model=model,
-            chosen=chosen,
-            rejected=rejected,
-            chosen_masks=chosen_masks,
-            rejected_masks=rejected_masks,
-            chosen_rewards=chosen_rewards,
-            rejected_rewards=rejected_rewards,
-            beta=beta,
-            reward_scaling=reward_scaling,
-        )
-        all_losses += loss * toks
-        all_rewards += reward
-        ntokens += toks
-        mx.eval(all_losses, all_rewards, ntokens)
-
-    all_losses = mx.distributed.all_sum(all_losses)
-    all_rewards = mx.distributed.all_sum(all_rewards)
-    ntokens = mx.distributed.all_sum(ntokens)
-
-    return (all_losses / ntokens).item(), all_rewards.tolist()
+def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
+    def get_logps(model, x, mask):
+        inputs = x[:, :-1]
+        targets = x[:, 1:]
+        logits = model(inputs)
+        logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
+        seq_lengths = mask[:, :-1].sum(-1)
+        logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
+        logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
+        return logp_sum, logits_mean
+
+    policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
+    policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
+
+    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
+    )
+
+    ratio = nn.log_sigmoid(log_odds)
+    loss = -beta * ratio
+
+    accuracies = (log_odds > 0).astype(mx.float32)
+    margins = mx.mean(ratio)
+    metrics = {
+        'accuracies': mx.mean(accuracies),
+        'margins': margins,
+        'policy_rejected_logps': mx.mean(policy_rejected_logps),
+        'policy_chosen_logps': mx.mean(policy_chosen_logps),
+        'rejected_logits_mean': mx.mean(rejected_logits_mean),
+        'chosen_logits_mean': mx.mean(chosen_logits_mean)
+    }
+
+    chosen_reward = beta * policy_chosen_logps
+    rejected_reward = beta * policy_rejected_logps
+    reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
+    num_tokens = chosen_masks.sum() + rejected_masks.sum()
+
+    return mx.mean(loss), reward, num_tokens, metrics

 def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
@@ -188,10 +108,6 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
         # Get preference scores and convert to rewards
         preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
-        # Convert preference scores to chosen/rejected rewards
-        # When preference_score is 1.0, chosen_reward=1.0, rejected_reward=0.0
-        # When preference_score is 0.0, chosen_reward=0.0, rejected_reward=1.0
-        # When preference_score is 0.5, both rewards are 0.5
         chosen_rewards = preference_scores
         rejected_rewards = 1.0 - preference_scores
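
A note on the rewritten orpo_loss above, in the code's own terms: get_logps returns a length-normalized sequence log-probability, so p = exp(logps) is an average per-token probability in (0, 1). The loss is the scaled log-sigmoid of the log odds ratio between the chosen and rejected completions:

    log_odds = log(p_chosen / (1 - p_chosen)) - log(p_rejected / (1 - p_rejected))
    loss     = -beta * log_sigmoid(log_odds)

The accuracies metric counts how often log_odds > 0, i.e. how often the chosen completion is favoured, and margins is the mean of log_sigmoid(log_odds). The chosen_rewards / rejected_rewards arguments are still accepted (the batch iterator passes them) but no longer enter the loss; the logged rewards are simply beta times the chosen and rejected log-probabilities.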
@@ -218,6 +134,56 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
             break

+def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
+    all_losses = 0
+    all_rewards = mx.zeros((2,))
+    all_metrics = None
+    ntokens = 0
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+    for _, batch in zip(
+        index_iterator,
+        iterate_orpo_batches(
+            dataset=dataset,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_seq_length=max_seq_length,
+        ),
+    ):
+        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
+        loss, reward, toks, metrics = orpo_loss(
+            model=model,
+            chosen=chosen,
+            rejected=rejected,
+            chosen_masks=chosen_masks,
+            rejected_masks=rejected_masks,
+            chosen_rewards=chosen_rewards,
+            rejected_rewards=rejected_rewards,
+            beta=beta
+        )
+        all_losses += loss * toks
+        all_rewards += reward * toks
+        ntokens += toks
+
+        if all_metrics is None:
+            all_metrics = {k: v * toks for k, v in metrics.items()}
+        else:
+            for k, v in metrics.items():
+                all_metrics[k] += v * toks
+
+        mx.eval(all_losses, all_rewards, ntokens)
+
+    all_losses = mx.distributed.all_sum(all_losses)
+    all_rewards = mx.distributed.all_sum(all_rewards)
+    ntokens = mx.distributed.all_sum(ntokens)
+    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
+
+    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
+    avg_rewards = (all_rewards / ntokens).tolist()
+    avg_loss = (all_losses / ntokens).item()
+
+    return avg_loss, avg_rewards, ntokens, avg_metrics
+
 def train_orpo(
     model,
     tokenizer,
@@ -227,9 +193,6 @@ def train_orpo(
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
     training_callback: TrainingCallback = None,
 ):
-    """
-    Training function for ORPO.
-    """
     print(f"Starting ORPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
@@ -246,7 +209,7 @@ def train_orpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
-        (loss, reward, toks), grad = loss_value_and_grad(
+        (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             chosen,
             rejected,
@@ -259,7 +222,7 @@ def train_orpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)
-        return loss, reward, toks
+        return loss, reward, toks, metrics

     def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
                      chosen_rewards, rejected_rewards):
@@ -271,8 +234,7 @@ def train_orpo(
             rejected_masks=rejected_masks,
             chosen_rewards=chosen_rewards,
             rejected_rewards=rejected_rewards,
-            beta=args.beta,
-            reward_scaling=args.reward_scaling
+            beta=args.beta
         )

     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
@@ -283,11 +245,19 @@ def train_orpo(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    accumulated_metrics = {
+        'accuracies': 0,
+        'margins': 0,
+        'policy_rejected_logps': 0,
+        'policy_chosen_logps': 0,
+        'rejected_logits_mean': 0,
+        'chosen_logits_mean': 0
+    }

     start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
-        iterate_orpo_batches(  # reuse DPO batch iterator
+        iterate_orpo_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
@@ -295,18 +265,16 @@ def train_orpo(
             train=True,
         ),
     ):
-        # Evaluate if needed
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
-            val_loss, val_rewards = evaluate_orpo(
+            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
                 model=model,
                 dataset=val_dataset,
                 tokenizer=tokenizer,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
-                beta=args.beta,
-                reward_scaling=args.reward_scaling,
+                beta=args.beta
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
@@ -315,6 +283,8 @@ def train_orpo(
                     f"Val loss {val_loss:.8f}, "
                     f"Val chosen reward {val_rewards[0]:.3f}, "
                     f"Val rejected reward {val_rewards[1]:.3f}, "
+                    f"Val accuracy {val_metrics['accuracies']:.3f}, "
+                    f"Val margin {val_metrics['margins']:.3f}, "
                     f"Val took {val_time:.3f}s",
                     flush=True,
                 )
@@ -325,25 +295,28 @@ def train_orpo(
                     "val_loss": val_loss,
                     "val_chosen_reward": val_rewards[0],
                     "val_rejected_reward": val_rewards[1],
+                    **{f"val_{k}": v for k, v in val_metrics.items()},
                     "val_time": val_time,
                 })

             start = time.perf_counter()

         # Training step
-        loss, reward, toks = step(batch)
+        loss, reward, toks, metrics = step(batch)
         losses += loss
         rewards += reward
         n_tokens += toks
         steps += 1
+        for k, v in metrics.items():
+            accumulated_metrics[k] += v
         mx.eval(state, losses, rewards, n_tokens)

-        # Report training metrics if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
             train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
             train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
+            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
             n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
@@ -356,10 +329,11 @@ def train_orpo(
                     f"Iter {it}: Train loss {train_loss:.8f}, "
                     f"Chosen reward {train_rewards[0]:.3f}, "
                     f"Rejected reward {train_rewards[1]:.3f}, "
+                    f"Accuracy {avg_metrics['accuracies']:.3f}, "
+                    f"Margin {avg_metrics['margins']:.3f}, "
                     f"Learning Rate {learning_rate:.3e}, "
                     f"It/sec {it_sec:.3f}, "
                     f"Tokens/sec {tokens_sec:.3f}, "
-                    f"Trained Tokens {trained_tokens}, "
                     f"Peak mem {peak_mem:.3f} GB",
                     flush=True,
                 )
@@ -370,6 +344,7 @@ def train_orpo(
                     "train_loss": train_loss,
                     "train_chosen_reward": train_rewards[0],
                     "train_rejected_reward": train_rewards[1],
+                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
                     "tokens_per_second": tokens_sec,
@@ -381,9 +356,9 @@ def train_orpo(
             rewards = mx.zeros((2,))
             n_tokens = 0
             steps = 0
+            accumulated_metrics = {k: 0 for k in accumulated_metrics}
             start = time.perf_counter()

-        # Save model weights if needed
         if it % args.steps_per_save == 0:
             adapter_weights = dict(tree_flatten(model.trainable_parameters()))
             mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -396,7 +371,6 @@ def train_orpo(
                 f"{args.adapter_file} and {checkpoint}."
             )

-    # Save final weights
     adapter_weights = dict(tree_flatten(model.trainable_parameters()))
     mx.save_safetensors(str(args.adapter_file), adapter_weights)
     print(f"Saved final weights to {args.adapter_file}.")