From d723ddfedad401ccd7f522891a1cd003c1fabc60 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 5 Mar 2025 14:49:56 +0100 Subject: [PATCH] updates --- llms/mlx_lm/tuner/grpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 098580b1..feb27737 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -343,6 +343,7 @@ def grpo_loss( # Convert to tensor prompt_tensor = mx.array(padded_prompts) + prompt_tensor = mx.stop_gradient(prompt_tensor) # Explicitly stop gradient on input try: mx.metal.clear_cache()