From e96afe9e9f0ef6e756c7b4fb6fd2bc100d7a6ecd Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 11 Feb 2025 09:09:28 +0100 Subject: [PATCH] updates --- llms/mlx_lm/tuner/datasets.py | 2 +- llms/mlx_lm/tuner/grpo_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index e7c913b6..5f00d3e3 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -41,7 +41,7 @@ class GRPODataset: prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . - User: {prompt_str}. Assistant: """) + User: {prompt_str} Assistant: """) else: prompt_tokens = tokenizer.encode(prompt_str) answer_tokens = tokenizer.encode(answer_str) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 4a1e6bbf..40f6c709 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -447,8 +447,8 @@ def evaluate_grpo( mx.eval(all_losses, ntokens) # Aggregate across distributed workers - all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) - ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) + all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu) all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} # Calculate averages