diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 40f6c709..4a1e6bbf 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -447,8 +447,8 @@ def evaluate_grpo( mx.eval(all_losses, ntokens) # Aggregate across distributed workers - all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu) - ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu) + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} # Calculate averages