From 06f9c29c940d31be83e50759415dcf11157f899f Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 3 Feb 2025 19:47:40 +0100 Subject: [PATCH] print func name --- llms/mlx_lm/tuner/grpo_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index a8b3a1c9..a1b8fcbd 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -309,13 +309,14 @@ def grpo_loss( # Collect reward metrics reward_metrics = {} for i, reward_func in enumerate(reward_funcs): + func_name = reward_func.__name__ func_rewards = mx.array(reward_func( prompts=expanded_prompts, completions=all_completion_texts, answer=expanded_answers )) - reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards) - reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards) + reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards) + reward_metrics[f'{func_name}_std'] = mx.std(func_rewards) # Clean up del all_completions