From 2d2f39f96e25b656929c7ac762dad69d5655b958 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Mar 2025 14:25:55 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_trainer.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 695c32fd..9375a757 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -175,6 +175,8 @@ def generate_grpo(
     try:
         import time

+        model.freeze()
+
         start_time = time.time()

         if len(prompts.shape) == 1:
@@ -229,8 +231,9 @@ def generate_grpo(
                 # Check for EOS token
                 if token == tokenizer.eos_token_id:
                     break
-
+
                 current_tokens.append(token)
+                mx.eval(current_tokens[-1])

                 # Check for end token
                 if len(current_tokens) >= len(end_sequence) and mx.array_equal(
@@ -268,7 +271,9 @@ def generate_grpo(
            )
            print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")

+        results = [mx.stop_gradient(r) for r in results]
         mx.eval(results)
+        model.unfreeze()
         return results

     except Exception as e:
@@ -307,6 +312,7 @@ def grpo_loss(
     temperature: float = 0.8,
     reward_weights: Optional[List[float]] = None,
     batch_size: int = 1,
+    is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
     total_samples = len(prompt_tokens)
@@ -321,6 +327,7 @@ def grpo_loss(

     # Set model to eval mode for generation
     model.eval()
+    print(f"Is model now in training mode: {model.training}")

     # Process in smaller batches
     for i in range(0, total_samples, batch_size):
@@ -340,6 +347,7 @@ def grpo_loss(
         prompt_tensor = mx.array(padded_prompts)

         try:
+            mx.metal.clear_cache()
             completions = generate_grpo(
                 model,
                 prompt_tensor,
@@ -363,6 +371,7 @@ def grpo_loss(
                     mx.eval(completion_ids)
         except Exception as e:
             print(f"Generation error: {e}")
+            print(f"Is model in training mode after generation: {model.training}")
             continue

         # Restore original training state if we're not in validation mode
@@ -712,7 +721,8 @@ def evaluate_grpo(
             epsilon=epsilon,
             ref_model=ref_model,
             temperature=temperature,
-            max_tokens=max_tokens
+            max_tokens=max_tokens,
+            is_validation=True
         )

         all_losses += losses * toks
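Note on the pattern above: the additions share one goal, keeping rollout generation from accumulating gradient state and memory during GRPO training. The model is frozen for the duration of generation, each sampled token is evaluated eagerly so MLX's lazy compute graph stays small, the completions are detached with mx.stop_gradient before they reach the loss, and the Metal buffer cache is cleared before each generation batch. The sketch below is a minimal, self-contained illustration of that pattern, not code from this patch; TinyLM and sample_tokens are hypothetical stand-ins for the trainer's model and sampling loop.

import mlx.core as mx
import mlx.nn as nn


class TinyLM(nn.Module):
    # Hypothetical stand-in for the policy model used by the trainer.
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.out = nn.Linear(dim, vocab_size)

    def __call__(self, tokens: mx.array) -> mx.array:
        return self.out(self.embed(tokens))


def sample_tokens(model: nn.Module, prompt: mx.array, max_tokens: int = 8) -> mx.array:
    # Mark parameters non-trainable for the rollout, so nn.value_and_grad
    # ignores them (the patch does this at the top of generate_grpo).
    model.freeze()
    # Release cached GPU buffers before generating (the patch does this
    # just before calling generate_grpo).
    mx.metal.clear_cache()

    tokens = prompt
    for _ in range(max_tokens):
        logits = model(tokens)[-1]
        next_token = mx.random.categorical(logits)
        tokens = mx.concatenate([tokens, next_token.reshape(1).astype(tokens.dtype)])
        # Evaluate eagerly so the lazy graph stays small, mirroring the
        # per-token mx.eval the patch adds inside the sampling loop.
        mx.eval(tokens)

    # Detach the completion before it is handed to the loss, then restore
    # the trainable state for the optimizer step.
    result = mx.stop_gradient(tokens)
    model.unfreeze()
    return result


model = TinyLM()
model.eval()  # generation runs in eval mode, as grpo_loss arranges
print(sample_tokens(model, mx.array([1, 2, 3])))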