This commit is contained in:
Goekdeniz-Guelmez 2025-02-09 14:32:50 +01:00
parent 94dcd0f63e
commit 9ba6146a76

View File

@ -302,7 +302,7 @@ def grpo_loss(
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
# Compute policy ratio
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
# Compute per-token loss following GRPO formula
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)