From 1f894532953bcc23fb447b3c0bb1e1d7b38d37f3 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Mar 2025 14:00:51 +0100
Subject: [PATCH] eos token return fix

---
 llms/mlx_lm/tuner/grpo_trainer.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 34389c53..695c32fd 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -225,6 +225,11 @@ def generate_grpo(
         )
     ):
         print(token)
+
+        # Check for EOS token
+        if token == tokenizer.eos_token_id:
+            break
+
         current_tokens.append(token)
 
         # Check for end token
@@ -233,10 +238,6 @@
         ):
             break
 
-        # Check for EOS token
-        if token == tokenizer.eos_token_id:
-            break
-
         # Check if we've reached the maximum number of tokens
         if i >= max_tokens - 1:
             break
@@ -305,7 +306,6 @@
     max_tokens: int = 64,
     temperature: float = 0.8,
     reward_weights: Optional[List[float]] = None,
-    is_validation: bool = False,
     batch_size: int = 1,
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
@@ -366,7 +366,7 @@
             continue
 
     # Restore original training state if we're not in validation mode
-    if not is_validation and was_training:
+    if was_training:
         model.train()
 
     mx.metal.clear_cache()
@@ -712,8 +712,7 @@
         epsilon=epsilon,
         ref_model=ref_model,
         temperature=temperature,
-        max_tokens=max_tokens,
-        is_validation=True,
+        max_tokens=max_tokens
     )
 
     all_losses += losses * toks