diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 9d938df8..8b20e71f 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -157,7 +157,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
         current_input = expanded_prompts[idx]
         prompt_cache = cache.make_prompt_cache(model)
 
-        # Initial forward pass with the prompt
         logits = model(current_input[None], cache=prompt_cache)[:, -1]
 
         while len(current_tokens) < max_tokens: