From 326935be49b8fc6740e857bb1b20d4eecf62b1f6 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Mar 2025 14:40:23 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_trainer.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 9375a757..28546cae 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -175,8 +175,6 @@ def generate_grpo(
     try:
         import time

-        model.freeze()
-
         start_time = time.time()

         if len(prompts.shape) == 1:
@@ -213,7 +211,6 @@ def generate_grpo(
             sample_start_time = time.time()
             current_tokens = []
             prompt_cache = cache.make_prompt_cache(model)
-            mx.eval(current_tokens, prompt_cache)

             # The generate_step function yields one token at a time
             # We'll collect tokens until we hit max_tokens or a stopping condition
@@ -226,14 +223,13 @@ def generate_grpo(
                     prompt_cache=prompt_cache,
                 )
             ):
-                print(token)
-
                 # Check for EOS token
                 if token == tokenizer.eos_token_id:
                     break

                 current_tokens.append(token)
-                mx.eval(current_tokens[-1])
+
+                print(token)

                 # Check for end token
                 if len(current_tokens) >= len(end_sequence) and mx.array_equal(
@@ -245,6 +241,8 @@ def generate_grpo(
                 if i >= max_tokens - 1:
                     break

+            mx.eval(current_tokens)
+
             if current_tokens:
                 results.append(mx.array(current_tokens))
                 total_tokens_generated += len(current_tokens)
@@ -273,7 +271,6 @@ def generate_grpo(

         results = [mx.stop_gradient(r) for r in results]
         mx.eval(results)
-        model.unfreeze()

         return results
     except Exception as e:
@@ -885,6 +882,8 @@ def train_grpo(
             n_tokens += toks
             steps += 1

+            mx.metal.clear_cache()
+
             for k, v in metrics.items():
                 accumulated_metrics[k] += v

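Editor's note, for readers skimming the diff: the core pattern the patch moves to is replacing a per-token `mx.eval(current_tokens[-1])` inside the decode loop with a single `mx.eval(current_tokens)` after it, and calling `mx.metal.clear_cache()` once per training step. Below is a minimal standalone sketch of that pattern, not code from the patch; `fake_step` is a hypothetical stand-in for mlx_lm's `generate_step`.

```python
import mlx.core as mx


def fake_step(i):
    # Hypothetical stand-in for one decode step: returns a lazy MLX scalar.
    return mx.array(i) + 1


tokens = []
for i in range(8):
    tok = fake_step(i)
    tokens.append(tok)
    # Eager variant (what the diff removes): forces a sync on every token.
    # mx.eval(tokens[-1])

# Batched variant (what the diff adds): evaluate the collected tokens once,
# letting MLX keep the decode graph lazy in between.
mx.eval(tokens)

# Between optimizer steps, releasing MLX's Metal buffer cache can reduce
# peak memory, at the cost of re-allocating buffers on the next step.
mx.metal.clear_cache()
```

Since MLX is lazily evaluated, each `mx.eval` call is a synchronization point; fewer, coarser evaluation points generally cost less than one per generated token.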