fix name funcs

This commit is contained in:
Goekdeniz-Guelmez 2025-02-03 19:56:11 +01:00
parent 06f9c29c94
commit 54e295ea80

View File

@ -574,12 +574,12 @@ def train_grpo(
)
# Add reward function specific metrics
for i in range(len(reward_funcs)):
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
)
print(
f"Iter {it}: {val_metrics_str}, "
f"Val took {val_time:.3f}s",
@ -630,10 +630,11 @@ def train_grpo(
)
# Add reward function specific metrics
for i in range(len(reward_funcs)):
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, "
f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}"
)
print(