diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 31b6dc58..a5307bd3 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -171,7 +171,7 @@ def generate_grpo( break current_tokens.append(token) - + if len(current_tokens) >= len(end_sequence) and mx.array_equal( mx.array(current_tokens[-len(end_sequence):]), end_sequence ): @@ -623,7 +623,8 @@ def train_grpo( prompt_tokens=prompt_tokens, max_tokens=args.max_completion_length, group_size=args.group_size, - temperature=args.temperature + temperature=args.temperature, + batch_size=args.batch_size ) (loss, toks, metrics), grad = loss_value_and_grad(