diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 098580b1..feb27737 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -343,6 +343,7 @@ def grpo_loss( # Convert to tensor prompt_tensor = mx.array(padded_prompts) + prompt_tensor = mx.stop_gradient(prompt_tensor) # Explicitly stop gradient on input try: mx.metal.clear_cache()