This commit is contained in:
Goekdeniz-Guelmez 2025-02-11 09:26:43 +01:00
parent e96afe9e9f
commit e80bf95182

View File

@ -447,8 +447,8 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages