mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 19:31:20 +08:00

clean up

This commit is contained in:
parent d8e7834345
commit 2f2ddd4811
@@ -32,7 +32,6 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
     policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
     policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
 
-    # Apply preference scores
     policy_chosen_logps = policy_chosen_logps * preference_scores
 
     log_odds = (policy_chosen_logps - policy_rejected_logps) - (
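Note on the last context line above: it ends mid-statement because the hunk boundary cuts it off. In the standard ORPO formulation, that statement contrasts the chosen and rejected log-probabilities with their log-complements to form an odds ratio. A minimal illustrative sketch of how such a term is typically written with MLX ops (the actual continuation lies outside this hunk and may differ):

# Illustrative only: standard ORPO log-odds term, not the repository's exact line.
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
    mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
)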
@@ -42,25 +41,25 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
     ratio = nn.log_sigmoid(log_odds)
     loss = -beta * ratio
 
-    metrics = {
-        'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
-        'margins': mx.mean(ratio - 1),
-        'policy_rejected_logps': mx.mean(policy_rejected_logps),
-        'policy_chosen_logps': mx.mean(policy_chosen_logps),
-        'rejected_logits_mean': mx.mean(rejected_logits_mean),
-        'chosen_logits_mean': mx.mean(chosen_logits_mean)
-    }
-
     chosen_reward = beta * policy_chosen_logps
     rejected_reward = beta * policy_rejected_logps
     reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
 
     num_tokens = chosen_masks.sum() + rejected_masks.sum()
 
+    metrics = {
+        'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
+        'margins': mx.mean(chosen_reward - rejected_reward),
+        'policy_rejected_logps': mx.mean(policy_rejected_logps),
+        'policy_chosen_logps': mx.mean(policy_chosen_logps),
+        'rejected_logits_mean': mx.mean(rejected_logits_mean),
+        'chosen_logits_mean': mx.mean(chosen_logits_mean)
+    }
+
     return mx.mean(loss), reward, num_tokens, metrics
 
 
-def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False):
+def iterate_orpo_batches(dataset, batch_size, max_seq_length, train=False):
     """Batch iterator for ORPO with preference scores"""
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
 
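To make the relocated metrics concrete, here is a small self-contained sketch (not part of the diff) that evaluates the new 'accuracies' and 'margins' definitions on made-up per-sequence log-probabilities; after this commit both are derived from the beta-scaled rewards rather than from the odds ratio:

import mlx.core as mx

beta = 0.1
# Dummy per-sequence log-probabilities, for illustration only.
policy_chosen_logps = mx.array([-12.0, -8.5, -10.0])
policy_rejected_logps = mx.array([-14.0, -9.0, -9.5])

chosen_reward = beta * policy_chosen_logps
rejected_reward = beta * policy_rejected_logps

# Fraction of pairs where the chosen response out-scores the rejected one,
# and the mean reward gap between them.
accuracies = mx.mean((chosen_reward > rejected_reward).astype(mx.float32))
margins = mx.mean(chosen_reward - rejected_reward)
print(accuracies.item(), margins.item())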
@@ -236,7 +235,6 @@ def train_orpo(
         range(1, args.iters + 1),
         iterate_orpo_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
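The context lines suggest the iteration counter and the batch iterator are consumed together inside the training loop; the pairing construct itself sits outside the hunk. A hypothetical sketch of the call site after this commit (only the keyword arguments are taken from the diff; the surrounding zip/for is assumed):

for it, batch in zip(
    range(1, args.iters + 1),
    iterate_orpo_batches(
        dataset=train_dataset,
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        train=True,
    ),
):
    ...  # one ORPO training step per batch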
@@ -247,7 +245,6 @@ def train_orpo(
         val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
             model=model,
             dataset=val_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             num_batches=args.val_batches,
             max_seq_length=args.max_seq_length,