diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index f995a05c..c29b2f5d 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -302,7 +302,7 @@ def grpo_loss(
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
     # Compute policy ratio
-    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
+    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
 
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
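
For reviewers, a minimal sketch (not part of the patch) of what the one-line change does. The log-probability values below are made up purely for illustration; only `mx.exp`, `mx.stop_gradient`, and `mx.array` from MLX are assumed:

import mlx.core as mx

# Hypothetical per-token log-probabilities, for illustration only.
token_log_probs = mx.array([-1.2, -0.7, -2.1])      # current policy
ref_token_log_probs = mx.array([-1.0, -0.9, -2.0])  # reference policy

# Old line: exp(x - stop_gradient(x)) is identically 1 in the forward pass,
# so the ratio never rescales the advantage; only the gradient of
# token_log_probs flows through (a REINFORCE-style surrogate).
old_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
print(old_ratio)  # [1.0, 1.0, 1.0]

# New line: the ratio compares the current policy against the reference
# policy, so tokens where the policy has drifted are re-weighted before
# being multiplied by the advantages in per_token_loss.
new_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
print(new_ratio)  # approx. [0.82, 1.22, 0.90]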