From 2f2ddd4811b281c6e6cbc8d02f0226edfd882bf7 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 26 Jan 2025 15:17:06 +0100
Subject: [PATCH] clean up

---
 llms/mlx_lm/tuner/orpo_trainer.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py
index 7aa07068..fb38c1e1 100644
--- a/llms/mlx_lm/tuner/orpo_trainer.py
+++ b/llms/mlx_lm/tuner/orpo_trainer.py
@@ -32,7 +32,6 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_
     policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
     policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
 
-    # Apply preference scores
     policy_chosen_logps = policy_chosen_logps * preference_scores
 
     log_odds = (policy_chosen_logps - policy_rejected_logps) - (
@@ -42,25 +41,25 @@
     ratio = nn.log_sigmoid(log_odds)
     loss = -beta * ratio
 
-    metrics = {
-        'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
-        'margins': mx.mean(ratio - 1),
-        'policy_rejected_logps': mx.mean(policy_rejected_logps),
-        'policy_chosen_logps': mx.mean(policy_chosen_logps),
-        'rejected_logits_mean': mx.mean(rejected_logits_mean),
-        'chosen_logits_mean': mx.mean(chosen_logits_mean)
-    }
-
     chosen_reward = beta * policy_chosen_logps
     rejected_reward = beta * policy_rejected_logps
     reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
 
     num_tokens = chosen_masks.sum() + rejected_masks.sum()
 
+    metrics = {
+        'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
+        'margins': mx.mean(chosen_reward - rejected_reward),
+        'policy_rejected_logps': mx.mean(policy_rejected_logps),
+        'policy_chosen_logps': mx.mean(policy_chosen_logps),
+        'rejected_logits_mean': mx.mean(rejected_logits_mean),
+        'chosen_logits_mean': mx.mean(chosen_logits_mean)
+    }
+
     return mx.mean(loss), reward, num_tokens, metrics
 
 
-def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False):
+def iterate_orpo_batches(dataset, batch_size, max_seq_length, train=False):
     """Batch iterator for ORPO with preference scores"""
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
 
@@ -236,7 +235,6 @@ def train_orpo(
         range(1, args.iters + 1),
         iterate_orpo_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
@@ -247,7 +245,6 @@
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
                 model=model,
                 dataset=val_dataset,
-                tokenizer=tokenizer,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
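
For readers comparing the old and new metrics, here is a minimal standalone sketch of the change. All inputs are made-up toy values: beta = 0.1 and the log-probabilities stand in for what get_logps returns in the trainer, and the odds-ratio expression is completed with the standard ORPO log1p(-exp(.)) terms, which fall outside the hunk above. Because the odds ratio is monotonic in the log-probability gap, the reward-based 'accuracies' introduced here agrees with the old log_odds > 0 test whenever the log-probabilities are valid (negative); the substantive change is 'margins', which moves from mean(log_sigmoid(log_odds) - 1) to the mean gap between the implicit rewards.

# Toy comparison of the metrics before and after this patch.
# Assumptions: beta and the log-probabilities are made-up values;
# in the trainer they come from args and get_logps respectively.
import mlx.core as mx
import mlx.nn as nn

beta = 0.1
policy_chosen_logps = mx.array([-1.0, -2.5, -0.5])
policy_rejected_logps = mx.array([-1.5, -2.0, -3.0])

# Odds-ratio term (unchanged by the patch); the log1p(-exp(.)) terms
# are the standard ORPO completion of the line truncated in the hunk.
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
    mx.log1p(-mx.exp(policy_chosen_logps))
    - mx.log1p(-mx.exp(policy_rejected_logps))
)
ratio = nn.log_sigmoid(log_odds)

# Old metrics: sign of the odds ratio, and a shifted log-sigmoid.
old_acc = mx.mean((log_odds > 0).astype(mx.float32))
old_margin = mx.mean(ratio - 1)

# New metrics: implicit rewards, as DPO-style trainers report them.
chosen_reward = beta * policy_chosen_logps
rejected_reward = beta * policy_rejected_logps
new_acc = mx.mean((chosen_reward > rejected_reward).astype(mx.float32))
new_margin = mx.mean(chosen_reward - rejected_reward)

print(old_acc.item(), new_acc.item())        # 0.666..., 0.666... (agree)
print(old_margin.item(), new_margin.item())  # different scales and meanings

The other half of the patch, dropping the apparently unused tokenizer argument from iterate_orpo_batches and from its two call sites in train_orpo, is a pure signature cleanup; nothing in the hunks above depends on it.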