diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 69603702..adc9363c 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -375,10 +375,21 @@ def grpo_loss(
         mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
     )

-    # Compute per-token loss
-    per_token_loss = -(
-        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
-    )
+    # Apply PPO-like clipping
+    policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
+
+    # Compute both the unclipped and clipped objectives
+    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
+    clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
+
+    # Take the minimum (pessimistic bound)
+    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)
+
+    # Add the KL penalty if beta is non-zero
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * kl_div
+
+    per_token_loss = per_token_loss * length_mask

     # Average over tokens
     loss = (per_token_loss * length_mask).sum() / length_mask.sum()  # Matches the pytorch implementaiton
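
For context, here is a minimal, self-contained sketch of the clipped-surrogate computation this hunk introduces, run on toy tensors so the effect of taking the minimum of the unclipped and clipped objectives is easy to inspect. The shapes and values are made up for illustration and are not taken from the trainer; only mlx.core operations already used in the diff are assumed.

    # Standalone illustration of the PPO-style clipped objective (toy values, not part of the PR)
    import mlx.core as mx

    epsilon = 0.2  # clipping range
    beta = 0.04    # KL penalty weight

    # Hypothetical per-token policy ratios pi_theta / pi_ref, shape (batch, tokens)
    policy_ratio = mx.array([[0.5, 1.0, 1.6],
                             [0.9, 1.3, 0.7]])
    # One advantage per sequence, broadcast across its tokens via reshape(-1, 1)
    advantages = mx.array([1.0, -1.0])
    # 1.0 for real tokens, 0.0 for padding
    length_mask = mx.array([[1.0, 1.0, 1.0],
                            [1.0, 1.0, 0.0]])
    # Hypothetical per-token KL estimates against the reference policy
    kl_div = mx.array([[0.01, 0.02, 0.05],
                       [0.03, 0.01, 0.02]])

    # Same steps as the diff above
    policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
    clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)  # pessimistic bound
    if beta != 0.0:
        per_token_loss = per_token_loss + beta * kl_div
    per_token_loss = per_token_loss * length_mask

    # Mean over unmasked tokens, as in the unchanged averaging line
    loss = (per_token_loss * length_mask).sum() / length_mask.sum()
    print(loss)

With these numbers, clipping caps the ratio at 1.2 for the positive-advantage sequence (limiting how much a single token can increase the objective) and floors it at 0.8 for the negative-advantage sequence, where the minimum keeps the more pessimistic of the two terms.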