mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
fix: prevent gradients from flowing through the reference model's logits
This commit is contained in:
parent
54179901b5
commit
a527cdb39b
@ -189,7 +189,7 @@ def compute_kl(logprobs1, logprobs2):
|
||||
|
||||
|
||||
def compute_policy_ratio(current_logprobs, ref_logprobs):
|
||||
return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
|
||||
return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
|
||||
|
||||
|
||||
def grpo_loss(
|
||||
|
Loading…
Reference in New Issue
Block a user