diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index e96b8f29..e75da0fd 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -280,7 +280,6 @@ def grpo_loss(
 
     # Stack rewards to shape (num_samples, num_funcs)
     rewards = mx.stack(all_func_rewards, axis=1)
-    print(f"Rewards: {rewards}")
 
     # Apply weights and sum
     if reward_weights is not None:
@@ -293,7 +292,6 @@ def grpo_loss(
     else:
         reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
     rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
-    print(f"Rewards after weights: {rewards}")
 
     # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
@@ -302,7 +300,7 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
 
     # Compute KL divergence using Schulman's approximator
-    kl_div = mx.exp(token_log_probs - ref_token_log_probs) - (token_log_probs - ref_token_log_probs) - 1
+    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
 
     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
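
Context on the KL change: with r = ref_token_log_probs - token_log_probs, the expression exp(r) - r - 1 is the per-token estimator of KL(pi_theta || pi_ref) used in GRPO (pi_ref/pi_theta - log(pi_ref/pi_theta) - 1), which is non-negative and unbiased under samples from the current policy; the previous ordering of the log-probs does not have that property. Below is a minimal sketch, not part of the patch, that checks the corrected form against the exact KL on a toy categorical distribution. The names `logits_cur` and `logits_ref` are illustrative only and do not appear in the trainer.

```python
# Sketch only: verify that Schulman's k3 estimator with
# r = log pi_ref - log pi_theta recovers KL(pi_theta || pi_ref)
# when averaged under pi_theta. Toy logits are made up for illustration.
import mlx.core as mx

logits_cur = mx.array([2.0, 0.5, -1.0, 0.0])   # hypothetical current-policy logits
logits_ref = mx.array([1.0, 1.0, 0.0, -0.5])   # hypothetical reference-policy logits

log_p_cur = logits_cur - mx.logsumexp(logits_cur)
log_p_ref = logits_ref - mx.logsumexp(logits_ref)
p_cur = mx.exp(log_p_cur)

# Exact KL(pi_theta || pi_ref) = E_{pi_theta}[log pi_theta - log pi_ref]
kl_exact = mx.sum(p_cur * (log_p_cur - log_p_ref))

# k3 approximator with the corrected sign, averaged under pi_theta
r = log_p_ref - log_p_cur
kl_k3 = mx.sum(p_cur * (mx.exp(r) - r - 1))

# The two values agree up to floating-point error.
print(kl_exact.item(), kl_k3.item())
```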