Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Adding PPO-like clipping adapted from trl
This commit is contained in:
parent
06ff47012f
commit
9fd6a5b6d0
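For context, the per-token objective this change implements is the standard PPO clipped surrogate with an optional KL penalty (generic notation, not taken from the commit itself):

    \mathcal{L}_t = -\min\big( r_t A,\ \mathrm{clip}(r_t,\, 1-\varepsilon,\, 1+\varepsilon)\, A \big) + \beta\, \mathrm{KL}_t

where r_t = exp(token_log_probs_t - ref_token_log_probs_t) is the policy ratio already computed in the hunk below, A is the sequence-level advantage broadcast over tokens, and the result is masked and averaged over valid tokens.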
@@ -375,10 +375,22 @@ def grpo_loss(
         mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
     )

-    # Compute per-token loss
-    per_token_loss = -(
-        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
-    )
+    # Apply PPO-like clipping
+    policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
+
+    # Calculate both unclipped and clipped objectives
+    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
+    clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
+
+    # Take the minimum (pessimistic bound)
+    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)
+
+    # Add KL penalty if beta is non-zero
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * kl_div
+
+
+    per_token_loss = per_token_loss * length_mask

     # Average over tokens
     loss = (per_token_loss * length_mask).sum() / length_mask.sum()  # Matches the PyTorch implementation
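Below is a minimal, standalone sketch of the clipped objective in the hunk above, using mlx.core on toy inputs. The function name, the epsilon/beta defaults, and the tensor shapes are illustrative only and are not part of the commit.

import mlx.core as mx

def clipped_per_token_loss(policy_ratio, advantages, kl_div, length_mask,
                           epsilon=0.2, beta=0.04):
    # PPO-style clipped surrogate with an optional KL penalty (illustrative sketch).
    ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
    clipped_obj = ratio_clipped * advantages.reshape(-1, 1)
    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)  # pessimistic bound
    if beta != 0.0:
        per_token_loss = per_token_loss + beta * kl_div       # optional KL penalty
    per_token_loss = per_token_loss * length_mask             # zero out padding positions
    return per_token_loss.sum() / length_mask.sum()           # average over valid tokens

# Toy example: 2 sequences, 3 token positions each
ratio = mx.array([[1.3, 0.9, 0.7], [1.1, 1.4, 1.0]])
adv = mx.array([0.5, -0.2])          # one advantage per sequence
kl = mx.zeros_like(ratio)            # no KL term in this toy run
mask = mx.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
print(clipped_per_token_loss(ratio, adv, kl, mask))

One consequence of taking mx.minimum over the two objectives: for positive advantages the gain from a large ratio is capped at 1 + epsilon, while for negative advantages the unclipped, more penalizing term is kept, which matches the usual PPO behavior.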