mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 19:31:20 +08:00
fix: prevent gradients from flowing through the reference model's logits
This commit is contained in:
parent
54179901b5
commit
a527cdb39b
@ -189,7 +189,7 @@ def compute_kl(logprobs1, logprobs2):
|
|||||||
|
|
||||||
|
|
||||||
def compute_policy_ratio(current_logprobs, ref_logprobs):
|
def compute_policy_ratio(current_logprobs, ref_logprobs):
|
||||||
return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
|
return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
|
||||||
|
|
||||||
|
|
||||||
def grpo_loss(
|
def grpo_loss(
|
||||||
|
Loading…
Reference in New Issue
Block a user