diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index a8b3a1c9..a1b8fcbd 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -309,13 +309,16 @@ def grpo_loss(
     # Collect reward metrics
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
+        # Key metrics by the reward function's name; callables that carry no
+        # __name__ (e.g. functools.partial) keep the positional label.
+        func_name = getattr(reward_func, '__name__', f'reward_func_{i}')
         func_rewards = mx.array(reward_func(
             prompts=expanded_prompts,
             completions=all_completion_texts,
             answer=expanded_answers
         ))
-        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
-        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
+        reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
+        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
 
     # Clean up
     del all_completions