From a527cdb39b85c8bb25642b6c1f9f66a225c49907 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 9 Feb 2025 17:02:58 +0100
Subject: [PATCH] fix: prevent gradients from flowing through the reference
 model's logits

---
 llms/mlx_lm/tuner/grpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index b9d58c01..b7bdc7dc 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -189,7 +189,7 @@ def compute_kl(logprobs1, logprobs2):
 
 
 def compute_policy_ratio(current_logprobs, ref_logprobs):
-    return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
+    return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
 
 
 def grpo_loss(
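
---

Note (not part of the patch): a minimal sketch of why the change matters, assuming mlx is installed. The helper name and toy inputs below are hypothetical; mx.stop_gradient, mx.exp, mx.array, and mx.grad are the actual MLX calls. With stop_gradient in place, the gradient with respect to the reference log-probs is zero, so the frozen reference model cannot receive updates through the policy ratio.

    import mlx.core as mx

    def policy_ratio(current_logprobs, ref_logprobs):
        # Gradients flow only through current_logprobs; ref_logprobs is
        # treated as a constant by the autograd graph.
        return mx.exp(current_logprobs - mx.stop_gradient(ref_logprobs))

    cur = mx.array([0.5, -1.0])  # hypothetical per-token log-probs (policy)
    ref = mx.array([0.2, -0.8])  # hypothetical per-token log-probs (reference)

    # Differentiate a scalar summary of the ratio w.r.t. both arguments.
    grad_fn = mx.grad(lambda c, r: policy_ratio(c, r).sum(), argnums=(0, 1))
    g_cur, g_ref = grad_fn(cur, ref)
    print(g_cur)  # nonzero: d/dc exp(c - r) = exp(c - r)
    print(g_ref)  # all zeros: stop_gradient blocks the backward pass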