diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py index 1a8ae42f..9e790273 100644 --- a/llms/mlx_lm/tuner/dpo_trainer.py +++ b/llms/mlx_lm/tuner/dpo_trainer.py @@ -218,7 +218,7 @@ def evaluate_dpo( for k, v in metrics.items(): all_metrics[k] += v * toks - mx.eval(all_losses, all_rewards, ntokens) + mx.eval(all_losses, all_rewards, ntokens) all_losses = mx.distributed.all_sum(all_losses) all_rewards = mx.distributed.all_sum(all_rewards) ntokens = mx.distributed.all_sum(ntokens)