This commit is contained in:
Goekdeniz-Guelmez
2025-02-11 09:09:28 +01:00
parent 88ca747e9e
commit e96afe9e9f
2 changed files with 3 additions and 3 deletions

View File

@@ -447,8 +447,8 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages