mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 03:41:17 +08:00
fix
This commit is contained in:
parent
94dcd0f63e
commit
9ba6146a76
@ -302,7 +302,7 @@ def grpo_loss(
|
||||
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
|
||||
|
||||
# Compute policy ratio
|
||||
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
|
||||
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
|
||||
|
||||
# Compute per-token loss following GRPO formula
|
||||
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
|
||||
|
Loading…
Reference in New Issue
Block a user