From 9ba6146a762830152b3e2c7733c68b7e15c2a8bf Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 9 Feb 2025 14:32:50 +0100
Subject: [PATCH] fix

---
 llms/mlx_lm/tuner/grpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index f995a05c..c29b2f5d 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -302,7 +302,7 @@ def grpo_loss(
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
     # Compute policy ratio
-    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
+    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
 
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
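
For reference, below is a minimal, self-contained sketch of the per-token GRPO loss step this patch touches. The dummy shapes, random inputs, and the k3-style KL estimate are assumptions for illustration only; in grpo_trainer.py the corresponding tensors (token_log_probs, ref_token_log_probs, advantages, kl_div, lengths) are produced upstream inside grpo_loss, and the length mask is computed over shifted targets rather than raw positions.

# Sketch of the per-token GRPO loss step changed in this patch.
# Shapes, random values, and the KL estimator below are illustrative
# assumptions, not the trainer's actual inputs.
import mlx.core as mx

batch, seq_len = 2, 5
beta = 0.1

# Per-token log-probs under the current policy and the frozen reference model.
token_log_probs = mx.random.normal((batch, seq_len))
ref_token_log_probs = mx.random.normal((batch, seq_len))

# Per-sequence advantages and valid-token lengths (dummy values).
advantages = mx.array([0.5, -0.5])
lengths = mx.array([5, 3])

# KL estimate between policy and reference (k3-style estimator, used here
# as a stand-in for whatever grpo_loss computes upstream).
log_ratio = ref_token_log_probs - token_log_probs
kl_div = mx.exp(log_ratio) - log_ratio - 1

# Mask out positions past each sequence's valid length (simplified:
# no target shift as in the trainer's length_mask).
length_mask = mx.arange(seq_len)[None, :] < lengths[:, None]

# Policy ratio as in the patched line:
# exp(log pi_theta - stop_gradient(log pi_ref)).
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))

# Per-token GRPO loss: -(ratio * advantage - beta * KL), masked to valid tokens.
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

# Reduce to a scalar: mean over valid tokens per sequence, then over the batch.
loss = (per_token_loss.sum(axis=1) / lengths).mean()
print(loss)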