mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-17 09:08:10 +08:00
fix: prevent gradients from flowing through the reference model's logits
This commit is contained in:
@@ -189,7 +189,7 @@ def compute_kl(logprobs1, logprobs2):
|
|||||||
|
|
||||||
|
|
||||||
def compute_policy_ratio(current_logprobs, ref_logprobs):
|
def compute_policy_ratio(current_logprobs, ref_logprobs):
|
||||||
return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
|
return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
|
||||||
|
|
||||||
|
|
||||||
def grpo_loss(
|
def grpo_loss(
|
||||||
|
Reference in New Issue
Block a user