From f1961f1b793b2e75567f7ff787d7caaa9d8fffdd Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sun, 9 Mar 2025 00:26:41 +0100 Subject: [PATCH] fix batch size --- llms/mlx_lm/tuner/grpo_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 31b6dc58..a5307bd3 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -171,7 +171,7 @@ def generate_grpo( break current_tokens.append(token) - + if len(current_tokens) >= len(end_sequence) and mx.array_equal( mx.array(current_tokens[-len(end_sequence):]), end_sequence ): @@ -623,7 +623,8 @@ def train_grpo( prompt_tokens=prompt_tokens, max_tokens=args.max_completion_length, group_size=args.group_size, - temperature=args.temperature + temperature=args.temperature, + batch_size=args.batch_size ) (loss, toks, metrics), grad = loss_value_and_grad(