mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
fix
This commit is contained in:
parent
e96afe9e9f
commit
e80bf95182
@ -447,8 +447,8 @@ def evaluate_grpo(
|
||||
mx.eval(all_losses, ntokens)
|
||||
|
||||
# Aggregate across distributed workers
|
||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
|
||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
|
||||
|
||||
# Calculate averages
|
||||
|
Loading…
Reference in New Issue
Block a user