This commit is contained in:
Goekdeniz-Guelmez 2025-01-26 15:17:06 +01:00
parent d8e7834345
commit 2f2ddd4811

View File

@ -32,7 +32,6 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks) policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks) policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
# Apply preference scores
policy_chosen_logps = policy_chosen_logps * preference_scores policy_chosen_logps = policy_chosen_logps * preference_scores
log_odds = (policy_chosen_logps - policy_rejected_logps) - ( log_odds = (policy_chosen_logps - policy_rejected_logps) - (
@ -42,25 +41,25 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
ratio = nn.log_sigmoid(log_odds) ratio = nn.log_sigmoid(log_odds)
loss = -beta * ratio loss = -beta * ratio
metrics = {
'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
'margins': mx.mean(ratio - 1),
'policy_rejected_logps': mx.mean(policy_rejected_logps),
'policy_chosen_logps': mx.mean(policy_chosen_logps),
'rejected_logits_mean': mx.mean(rejected_logits_mean),
'chosen_logits_mean': mx.mean(chosen_logits_mean)
}
chosen_reward = beta * policy_chosen_logps chosen_reward = beta * policy_chosen_logps
rejected_reward = beta * policy_rejected_logps rejected_reward = beta * policy_rejected_logps
reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)]) reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
num_tokens = chosen_masks.sum() + rejected_masks.sum() num_tokens = chosen_masks.sum() + rejected_masks.sum()
metrics = {
'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
'margins': mx.mean(chosen_reward - rejected_reward),
'policy_rejected_logps': mx.mean(policy_rejected_logps),
'policy_chosen_logps': mx.mean(policy_chosen_logps),
'rejected_logits_mean': mx.mean(rejected_logits_mean),
'chosen_logits_mean': mx.mean(chosen_logits_mean)
}
return mx.mean(loss), reward, num_tokens, metrics return mx.mean(loss), reward, num_tokens, metrics
def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False): def iterate_orpo_batches(dataset, batch_size, max_seq_length, train=False):
"""Batch iterator for ORPO with preference scores""" """Batch iterator for ORPO with preference scores"""
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen'])) idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
@ -236,7 +235,6 @@ def train_orpo(
range(1, args.iters + 1), range(1, args.iters + 1),
iterate_orpo_batches( iterate_orpo_batches(
dataset=train_dataset, dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size, batch_size=args.batch_size,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
train=True, train=True,
@ -247,7 +245,6 @@ def train_orpo(
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo( val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
model=model, model=model,
dataset=val_dataset, dataset=val_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size, batch_size=args.batch_size,
num_batches=args.val_batches, num_batches=args.val_batches,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,