diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index b9d58c01..b7bdc7dc 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -189,7 +189,7 @@ def compute_kl(logprobs1, logprobs2):
 
 
 def compute_policy_ratio(current_logprobs, ref_logprobs):
-    return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
+    return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
 
 
 def grpo_loss(