adding PPO-like clipping adapted from TRL

Goekdeniz-Guelmez 2025-03-11 09:08:38 +01:00
parent 06ff47012f
commit 9fd6a5b6d0


@@ -375,10 +375,22 @@ def grpo_loss(
         mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
     )
 
-    # Compute per-token loss
-    per_token_loss = -(
-        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
-    )
+    # Apply PPO-like clipping
+    policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
+
+    # Calculate both unclipped and clipped objectives
+    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
+    clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
+
+    # Take the minimum (pessimistic bound)
+    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)
+
+    # Add KL penalty if beta is non-zero
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * kl_div
+
+    per_token_loss = per_token_loss * length_mask
+
 
     # Average over tokens
     loss = (per_token_loss * length_mask).sum() / length_mask.sum()  # Matches the PyTorch implementation
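
For context, the commit replaces the plain policy-gradient term with the standard PPO clipped surrogate, min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A), where r is the per-token probability ratio against the stop-gradient reference and A is the advantage. Below is a minimal, self-contained sketch of what mx.clip and mx.minimum compute in this loss; the epsilon, ratio, and advantage values are toy numbers chosen for illustration, not values taken from this repository:

import mlx.core as mx

epsilon = 0.2                                # illustrative clipping range, not the repo default
policy_ratio = mx.array([[0.5, 1.0, 1.6]])   # per-token ratio of current vs. reference policy (toy values)
advantages = mx.array([1.0])                 # one advantage per sequence (toy value)

# Clip the ratio into [1 - epsilon, 1 + epsilon]
policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)

# Unclipped and clipped surrogate objectives
unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)

# Pessimistic bound: the minimum caps how much a single update can move the policy
per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)
print(per_token_loss)  # -> [[-0.5, -1.0, -1.2]]; only the last token's objective is clipped

With a positive advantage, ratios above 1 + epsilon stop earning extra reward (the 1.6 entry is capped at 1.2), which is the trust-region-style behaviour of PPO's clipped objective that the commit message credits to TRL.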