diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index ca6192ad..36b44ac2 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -322,6 +322,7 @@ def grpo_loss( answer=expanded_answers )) reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards) + reward_metrics[f'{func_name}_std'] = mx.std(func_rewards) metrics = { 'total_rewards_mean': mx.mean(rewards),