print func name

This commit is contained in:
Goekdeniz-Guelmez 2025-02-03 19:47:40 +01:00
parent 40bca770ae
commit 06f9c29c94

View File

@@ -309,13 +309,14 @@ def grpo_loss(
     # Collect reward metrics
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
+        func_name = reward_func.__name__
         func_rewards = mx.array(reward_func(
             prompts=expanded_prompts,
             completions=all_completion_texts,
             answer=expanded_answers
         ))
-        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
-        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
+        reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
+        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
     # Clean up
     del all_completions