diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index e7c913b6..5f00d3e3 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -41,7 +41,7 @@ class GRPODataset:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
- User: {prompt_str}. Assistant: """)
+ User: {prompt_str} Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 4a1e6bbf..40f6c709 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -447,8 +447,8 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
- all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
- ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+ all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
+ ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages
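
Note on the second hunk: mx.distributed.all_sum is a collective that returns the element-wise sum of its input across all workers, and like other MLX ops it accepts a stream argument selecting the device it runs on; the patch moves the two loss/token collectives from the CPU stream to the GPU stream. A minimal sketch of the aggregation pattern being changed, with hypothetical per-worker values standing in for the accumulators built up during evaluation:

    import mlx.core as mx

    # Hypothetical per-worker accumulators (in evaluate_grpo these are
    # summed over the evaluation batches before being aggregated).
    all_losses = mx.array(1.5)
    ntokens = mx.array(128)

    # Each worker contributes its local value and receives the global sum.
    # This is the call site the patch switches from stream=mx.cpu to mx.gpu.
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)

    # After the collectives, every worker computes the same average loss.
    avg_loss = (all_losses / ntokens).item()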