diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index 4ddc3d2e..ed955e01 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -128,7 +128,6 @@ def dpo_loss(
         'chosen_logits_mean': mx.mean(policy_chosen_score)
     }
-
     return mx.mean(losses), reward, num_tokens, metrics
@@ -180,7 +179,6 @@ def evaluate_dpo(
     model,
     reference_model,
     dataset,
-    tokenizer,
     batch_size,
     num_batches,
     beta: float,
@@ -328,7 +326,6 @@ def train_dpo(
                model=model,
                reference_model=reference_model,
                dataset=val_dataset,
-               tokenizer=tokenizer,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,