diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 1f9ef18a..ca6192ad 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -180,15 +180,6 @@ def get_per_token_logps(model, inputs, lengths):
     return per_token_logps
 
 
-def compute_kl(logprobs1, logprobs2):
-    ratio = mx.exp(logprobs1 - logprobs2)
-    return ratio - 1 - (logprobs1 - logprobs2)
-
-
-def compute_policy_ratio(current_logprobs, ref_logprobs):
-    return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
-
-
 def grpo_loss(
     model,
     ref_model,
@@ -302,13 +293,13 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
 
     # Compute KL divergence using Schulman's approximator
-    kl_div = compute_kl(token_log_probs, ref_token_log_probs)
+    kl_div = mx.exp(token_log_probs - ref_token_log_probs) - (token_log_probs - ref_token_log_probs) - 1
 
     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
     # Compute policy ratio
-    policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)
+    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))
 
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
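
For reference, here is a minimal standalone sketch (not part of the diff) checking that the inlined expressions in `grpo_loss` reproduce the removed `compute_kl` and `compute_policy_ratio` helpers. The log-prob arrays are invented for illustration only.

```python
# Illustration only: verify the inlined KL approximator and policy ratio
# match the removed helpers. Values are made up for the example.
import mlx.core as mx

token_log_probs = mx.array([-1.2, -0.7, -2.3])      # hypothetical per-token log-probs (policy)
ref_token_log_probs = mx.array([-1.0, -0.9, -2.0])  # hypothetical per-token log-probs (reference)

# Schulman-style KL approximator, as now inlined:
#   exp(d) - d - 1  with  d = token_log_probs - ref_token_log_probs
diff = token_log_probs - ref_token_log_probs
kl_div = mx.exp(diff) - diff - 1

# The removed compute_kl helper computed exp(d) - 1 - d, the same value up to term order.
kl_div_old = mx.exp(diff) - 1 - diff
assert mx.allclose(kl_div, kl_div_old).item()

# Policy ratio with the reference log-probs treated as constants, as inlined.
policy_ratio = mx.exp(
    mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32)
)
print(kl_div, policy_ratio)
```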