From e80bf9518278ce189cf55968d2cd897d10e1a15e Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 11 Feb 2025 09:26:43 +0100 Subject: [PATCH] fix --- llms/mlx_lm/tuner/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 40f6c709..4a1e6bbf 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -447,8 +447,8 @@ def evaluate_grpo( mx.eval(all_losses, ntokens) # Aggregate across distributed workers - all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu) - ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu) + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} # Calculate averages