fix: prevent gradients from flowing through the reference model's logits

This commit is contained in:
Goekdeniz-Guelmez 2025-02-09 17:02:58 +01:00
parent 54179901b5
commit a527cdb39b

View File

@ -189,7 +189,7 @@ def compute_kl(logprobs1, logprobs2):
def compute_policy_ratio(current_logprobs, ref_logprobs):
    """Return the policy ratio exp(log p_current - log p_ref) as float32.

    mx.stop_gradient blocks gradients from flowing through the reference
    model's log-probabilities, so the reference model is treated as a
    constant during optimization (the point of this fix).

    Args:
        current_logprobs: log-probabilities from the policy being trained.
        ref_logprobs: log-probabilities from the frozen reference model.

    Returns:
        An mx.array of element-wise ratios, dtype float32.
    """
    # NOTE(review): the diff rendering lost its +/- markers; the stale
    # pre-fix return (no stop_gradient) has been dropped here, keeping
    # only the corrected post-commit line.
    return mx.exp(
        mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32)
    )
def grpo_loss(