small fix

This commit is contained in:
Goekdeniz-Guelmez 2025-01-31 17:19:55 +01:00
parent b31d9cbb65
commit b379359385

View File

@ -218,7 +218,7 @@ def evaluate_dpo(
for k, v in metrics.items():
all_metrics[k] += v * toks
mx.eval(all_losses, all_rewards, ntokens)
mx.eval(all_losses, all_rewards, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens)