Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-27 11:21:32 +08:00)
clean up

commit 2f2ddd4811 (parent d8e7834345)
@@ -32,7 +32,6 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
     policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
     policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)

     # Apply preference scores
     policy_chosen_logps = policy_chosen_logps * preference_scores

     log_odds = (policy_chosen_logps - policy_rejected_logps) - (
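Note: the log_odds assignment above is cut off at the hunk boundary. For orientation only, here is a minimal sketch of the standard ORPO odds-ratio term, which matches the visible opening of that line; the continuation and the helper name are assumptions, not taken from this diff.

import mlx.core as mx

def orpo_log_odds_sketch(policy_chosen_logps, policy_rejected_logps):
    # With length-normalized log-probs g, odds(p) = p / (1 - p) gives
    # log odds = g - log(1 - exp(g)) = g - log1p(-exp(g)); ORPO's ratio term
    # is the difference of these for the chosen and rejected responses.
    return (policy_chosen_logps - policy_rejected_logps) - (
        mx.log1p(-mx.exp(policy_chosen_logps))
        - mx.log1p(-mx.exp(policy_rejected_logps))
    )

The next hunk then takes ratio = nn.log_sigmoid(log_odds) and loss = -beta * ratio, i.e. the odds-ratio penalty of the ORPO objective.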
@@ -42,25 +41,25 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
     ratio = nn.log_sigmoid(log_odds)
     loss = -beta * ratio
-
-    metrics = {
-        'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
-        'margins': mx.mean(ratio - 1),
-        'policy_rejected_logps': mx.mean(policy_rejected_logps),
-        'policy_chosen_logps': mx.mean(policy_chosen_logps),
-        'rejected_logits_mean': mx.mean(rejected_logits_mean),
-        'chosen_logits_mean': mx.mean(chosen_logits_mean)
-    }

     chosen_reward = beta * policy_chosen_logps
     rejected_reward = beta * policy_rejected_logps
     reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])

     num_tokens = chosen_masks.sum() + rejected_masks.sum()

+    metrics = {
+        'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
+        'margins': mx.mean(chosen_reward - rejected_reward),
+        'policy_rejected_logps': mx.mean(policy_rejected_logps),
+        'policy_chosen_logps': mx.mean(policy_chosen_logps),
+        'rejected_logits_mean': mx.mean(rejected_logits_mean),
+        'chosen_logits_mean': mx.mean(chosen_logits_mean)
+    }
+
     return mx.mean(loss), reward, num_tokens, metrics


-def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False):
+def iterate_orpo_batches(dataset, batch_size, max_seq_length, train=False):
     """Batch iterator for ORPO with preference scores"""
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))

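get_logps is called above but defined outside the visible hunks. As a rough sketch of what such a helper typically computes, matching the two values consumed here (mask-averaged label log-probabilities plus a mean logit); the shift convention, shapes, and names are assumptions:

import mlx.core as mx
import mlx.nn as nn

def get_logps_sketch(model, tokens, masks):
    # Hypothetical reconstruction: next-token prediction, so logits at
    # position t score the token at position t + 1.
    logits = model(tokens[:, :-1])
    targets = tokens[:, 1:]
    mask = masks[:, 1:].astype(logits.dtype)

    # Per-token log p(target); cross_entropy returns the negative log-likelihood.
    per_token_logps = -nn.losses.cross_entropy(logits, targets, reduction="none")

    # Length-normalized sequence log-probability (ORPO uses average log-probs).
    seq_logps = (per_token_logps * mask).sum(-1) / mx.maximum(mask.sum(-1), 1e-6)

    # Mean logit over unmasked positions, for the *_logits_mean metrics.
    logits_mean = (logits * mx.expand_dims(mask, -1)).sum() / mx.maximum(mask.sum(), 1e-6)
    return seq_logps, logits_mean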
@@ -236,7 +235,6 @@ def train_orpo(
         range(1, args.iters + 1),
         iterate_orpo_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
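For context, the hunk above is presumably the argument list of a zip() that drives the training loop; a minimal sketch of that pattern (the helper name and its parameters are assumptions, not part of the diff):

def training_batches(args, train_dataset, iterate_orpo_batches):
    # Pair iteration numbers with ORPO batches, mirroring the call above;
    # args, train_dataset, and iterate_orpo_batches are passed in so the
    # sketch stays self-contained.
    return zip(
        range(1, args.iters + 1),
        iterate_orpo_batches(
            dataset=train_dataset,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    )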
@@ -247,7 +245,6 @@ def train_orpo(
         val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
             model=model,
             dataset=val_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             num_batches=args.val_batches,
             max_seq_length=args.max_seq_length,
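evaluate_orpo itself is not shown in this diff. A hypothetical sketch of a validation loop consistent with the call above and with the new iterate_orpo_batches signature; the batch layout and aggregation are assumptions:

import mlx.core as mx

def evaluate_orpo_sketch(model, dataset, batch_size, num_batches,
                         max_seq_length, orpo_loss, iterate_orpo_batches):
    # Token-weighted average loss, mean rewards, and averaged metrics over a
    # few validation batches.
    losses, rewards, all_metrics, ntokens = [], [], {}, 0
    batches = iterate_orpo_batches(
        dataset=dataset, batch_size=batch_size, max_seq_length=max_seq_length
    )
    for _, batch in zip(range(num_batches), batches):
        # Assumed batch layout: matches orpo_loss's positional arguments.
        chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
        loss, reward, toks, metrics = orpo_loss(
            model, chosen, rejected, chosen_masks, rejected_masks, preference_scores
        )
        losses.append(loss * toks)   # weight each batch loss by its token count
        rewards.append(reward)       # [mean chosen_reward, mean rejected_reward]
        ntokens += toks
        for k, v in metrics.items():
            all_metrics[k] = all_metrics.get(k, 0.0) + v
    n = len(losses)
    val_loss = mx.sum(mx.stack(losses)) / ntokens
    val_rewards = mx.mean(mx.stack(rewards), axis=0)
    val_metrics = {k: v / n for k, v in all_metrics.items()}
    return val_loss, val_rewards, ntokens, val_metrics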