From 9fd6a5b6d0f05a9843d87f2934b2a05dabf32565 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 11 Mar 2025 09:08:38 +0100
Subject: [PATCH] adding PPO-like clipping adapted from trl

---
 llms/mlx_lm/tuner/grpo_trainer.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 69603702..adc9363c 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -375,10 +375,22 @@ def grpo_loss(
         mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
     )
 
-    # Compute per-token loss
-    per_token_loss = -(
-        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
-    )
+    # Apply PPO-like clipping
+    policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
+
+    # Calculate both unclipped and clipped objectives
+    unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
+    clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
+
+    # Take the minimum (pessimistic bound)
+    per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)
+
+    # Add KL penalty if beta is non-zero
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * kl_div
+
+    # Mask out padding positions
+    per_token_loss = per_token_loss * length_mask
 
     # Average over tokens
     loss = (per_token_loss * length_mask).sum() / length_mask.sum()  # Matches the pytorch implementaiton
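
Note (not part of the patch): below is a minimal standalone sketch of the clipped surrogate objective the hunk above implements, run on toy mlx arrays so the masking and averaging can be checked in isolation. The hyperparameter values (epsilon, beta) and the toy tensor shapes are assumptions for illustration, not values taken from the repository.

# Illustrative sketch of the PPO-style clipped objective on toy data.
import mlx.core as mx

epsilon = 0.2   # assumed clipping range
beta = 0.04     # assumed KL penalty weight

# Toy per-token quantities: 2 sequences x 3 tokens.
policy_ratio = mx.array([[1.3, 0.7, 1.0], [0.5, 1.8, 1.1]])
kl_div = mx.array([[0.02, 0.01, 0.00], [0.05, 0.03, 0.01]])
advantages = mx.array([1.0, -1.0])               # one advantage per sequence
length_mask = mx.array([[1, 1, 1], [1, 1, 0]])   # last token of seq 2 is padding

# Clip the ratio and take the pessimistic (minimum) objective per token.
policy_ratio_clipped = mx.clip(policy_ratio, 1 - epsilon, 1 + epsilon)
unclipped_obj = policy_ratio * advantages.reshape(-1, 1)
clipped_obj = policy_ratio_clipped * advantages.reshape(-1, 1)
per_token_loss = -mx.minimum(unclipped_obj, clipped_obj)

# Optional KL penalty, then mask out padding and average over valid tokens.
if beta != 0.0:
    per_token_loss = per_token_loss + beta * kl_div
loss = (per_token_loss * length_mask).sum() / length_mask.sum()
print(loss)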