diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index a9584bda..34389c53 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -129,7 +129,7 @@ def generate_step(
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
         y = sampler(logprobs)
-        return y, logprobs.squeeze(0)
+        return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

     with mx.stream(generation_stream):
         total_prompt_tokens = y.size
@@ -186,7 +186,7 @@ def generate_grpo(
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
-    mx.eval(expanded_prompts)
+    mx.eval(expanded_prompts, results)

     print(f"Setup time: {time.time() - start_time:.2f}s")
     print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
@@ -211,6 +211,7 @@ def generate_grpo(
         sample_start_time = time.time()
         current_tokens = []
         prompt_cache = cache.make_prompt_cache(model)
+        mx.eval(current_tokens, prompt_cache)

         # The generate_step function yields one token at a time
         # We'll collect tokens until we hit max_tokens or a stopping condition
@@ -316,6 +317,7 @@ def grpo_loss(
     # Store original training state
     was_training = model.training
+    print(f"Was model in training mode: {was_training}")

     # Set model to eval mode for generation
     model.eval()
@@ -352,9 +354,8 @@ def grpo_loss(
         for j, completion_ids in enumerate(completions):
             # Calculate which prompt this completion belongs to
             prompt_idx = i + (j // group_size)
-            if (
-                prompt_idx < total_samples
-            ):  # Make sure we don't go out of bounds
+
+            if prompt_idx < total_samples:  # Make sure we don't go out of bounds
                 batch_indices.append(prompt_idx)
                 completion_text = tokenizer.decode(completion_ids.tolist())
                 all_completions.append(completion_ids)
@@ -367,7 +368,6 @@ def grpo_loss(
     # Restore original training state if we're not in validation mode
     if not is_validation and was_training:
         model.train()
-    mx.metal.clear_cache()

     # If we didn't generate any completions, return early
@@ -376,7 +376,6 @@ def grpo_loss(
             "No completions were generated. Please check your model and inputs."
         )

-    # The rest of the function remains the same
     # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []
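
For context on the `mx.stop_gradient` change in `generate_step`: MLX builds computation graphs lazily, so returning the sampled token and its logprobs undetached keeps the whole sampling graph reachable, and memory grows with every generated token during rollouts. Below is a minimal, self-contained sketch of the detach-on-return pattern; the shapes and the greedy `mx.argmax` stand-in for the real `sampler` are hypothetical, not the trainer's actual sampler.

```python
import mlx.core as mx

def sample_step(logits):
    # Normalize to log-probabilities, as generate_step does.
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    y = mx.argmax(logprobs, axis=-1)  # greedy stand-in for sampler(logprobs)
    # Detach both outputs: rollout generation is inference-only in GRPO;
    # the loss later recomputes per-token logprobs under the trainable
    # model, so keeping the sampling graph alive only wastes memory.
    return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

logits = mx.random.normal((1, 32))  # batch of 1, hypothetical vocab of 32
token, logprobs = sample_step(logits)
mx.eval(token, logprobs)  # force the lazy computation at a known point
```

The added `mx.eval(...)` calls in the diff follow the same idea from the other direction: evaluating values eagerly at known points keeps the lazy graph from accumulating across iterations of the generation loop.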