diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 9375a757..28546cae 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -175,8 +175,6 @@ def generate_grpo( try: import time - model.freeze() - start_time = time.time() if len(prompts.shape) == 1: @@ -213,7 +211,6 @@ def generate_grpo( sample_start_time = time.time() current_tokens = [] prompt_cache = cache.make_prompt_cache(model) - mx.eval(current_tokens, prompt_cache) # The generate_step function yields one token at a time # We'll collect tokens until we hit max_tokens or a stopping condition @@ -226,14 +223,13 @@ def generate_grpo( prompt_cache=prompt_cache, ) ): - print(token) - # Check for EOS token if token == tokenizer.eos_token_id: break current_tokens.append(token) - mx.eval(current_tokens[-1]) + + print(token) # Check for end token if len(current_tokens) >= len(end_sequence) and mx.array_equal( @@ -245,6 +241,8 @@ def generate_grpo( if i >= max_tokens - 1: break + mx.eval(current_tokens) + if current_tokens: results.append(mx.array(current_tokens)) total_tokens_generated += len(current_tokens) @@ -273,7 +271,6 @@ def generate_grpo( results = [mx.stop_gradient(r) for r in results] mx.eval(results) - model.unfreeze() return results except Exception as e: @@ -885,6 +882,8 @@ def train_grpo( n_tokens += toks steps += 1 + mx.metal.clear_cache() + for k, v in metrics.items(): accumulated_metrics[k] += v