diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index 22797c7e..8a3590fa 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -176,12 +176,12 @@ def evaluate_dpo(
     num_batches,
     beta: float,
     delta: float,
-    max_seq_length=2048,
-    loss_fn: callable = dpo_loss,
-    loss_type="sigmoid",
+    max_seq_length,
+    loss_type,
+    loss_fn: callable = dpo_loss
 ):
     all_losses = 0
-    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
+    all_rewards = mx.zeros((2,))
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -195,6 +195,7 @@ def evaluate_dpo(
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch
+
         loss, reward, toks = loss_fn(
             model=model,
             reference_teacher_model=reference_model,
@@ -206,18 +207,18 @@ def evaluate_dpo(
             beta=beta,
             delta=delta,
         )
-
         all_losses += loss * toks
         all_rewards += reward
         ntokens += toks
-
         mx.eval(all_losses, all_rewards, ntokens)
 
     all_losses = mx.distributed.all_sum(all_losses)
     all_rewards = mx.distributed.all_sum(all_rewards)
-    ntokens = mx.distributed.all_sum(ntokens)
+    ntokens = mx.distributed.all_sum(ntokens)
+
     return (all_losses / ntokens).item(), all_rewards.tolist()
 
+
 def train_dpo(
     model,
     reference_model,