Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Adding PPO-like clipping adapted from trl
This commit is contained in:
parent 06ff47012f
commit 9fd6a5b6d0
@@ -375,10 +375,22 @@ def grpo_loss(
         mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
     )
 
-    # Compute per-token loss
-    per_token_loss = -(
-        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
-    )
+    # Apply PPO-like clipping
+    policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
+
+    # Calculate both unclipped and clipped objectives
+    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
+    clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
+
+    # Take the minimum (pessimistic bound)
+    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)
+
+    # Add KL penalty if beta is non-zero
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * kl_div
+
+
+    per_token_loss = per_token_loss * length_mask
 
     # Average over tokens
     loss = (per_token_loss * length_mask).sum() / length_mask.sum()  # Matches the PyTorch implementation
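For readers unfamiliar with the clipped surrogate this commit borrows from PPO (as implemented in trl), the sketch below reproduces the same min(unclipped, clipped) construction on toy MLX arrays. The names (policy_ratio, advantages, epsilon, beta, kl_div, length_mask) mirror the diff, but every value is a made-up placeholder chosen only to show the clipping behaviour; this is an illustrative sketch, not the grpo_loss implementation itself.

# Illustrative sketch of the PPO-style clipped objective added in the diff above.
# All values are toy placeholders; only the structure mirrors the commit.
import mlx.core as mx

epsilon = 0.2  # clipping range (assumed hyperparameter for this sketch)
beta = 0.04    # KL penalty weight (assumed hyperparameter for this sketch)

# Toy per-token quantities for a batch of 2 sequences x 3 tokens
policy_ratio = mx.array([[0.5, 1.0, 1.6], [0.9, 1.3, 2.0]])  # pi_theta / pi_ref per token
advantages = mx.array([1.0, -1.0])                            # one advantage per sequence
kl_div = mx.array([[0.01, 0.02, 0.05], [0.0, 0.03, 0.1]])     # per-token KL estimate
length_mask = mx.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])    # last token of seq 2 is padding

# Clip the ratio into [1 - epsilon, 1 + epsilon]
policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)

# Unclipped vs clipped objectives; take the pessimistic (smaller) one per token
unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)

# Optional KL penalty, then mask out padding and average over valid tokens
if beta != 0.0:
    per_token_loss = per_token_loss + beta * kl_div
per_token_loss = per_token_loss * length_mask
loss = (per_token_loss * length_mask).sum() / length_mask.sum()
print(loss)

Taking the element-wise minimum means the objective never rewards pushing the ratio beyond the [1 - epsilon, 1 + epsilon] trust region in the direction favoured by the advantage, which is the pessimistic bound PPO uses and which the commit title says was adapted from trl.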