mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 04:31:13 +08:00
fix name funcs
This commit is contained in:
parent
06f9c29c94
commit
54e295ea80
@ -574,12 +574,12 @@ def train_grpo(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add reward function specific metrics
|
# Add reward function specific metrics
|
||||||
for i in range(len(reward_funcs)):
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
val_metrics_str += (
|
val_metrics_str += (
|
||||||
f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
|
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
|
||||||
f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
|
f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Iter {it}: {val_metrics_str}, "
|
f"Iter {it}: {val_metrics_str}, "
|
||||||
f"Val took {val_time:.3f}s",
|
f"Val took {val_time:.3f}s",
|
||||||
@ -630,10 +630,11 @@ def train_grpo(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add reward function specific metrics
|
# Add reward function specific metrics
|
||||||
for i in range(len(reward_funcs)):
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
|
func_name = reward_func.__name__
|
||||||
train_metrics_str += (
|
train_metrics_str += (
|
||||||
f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
|
f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, "
|
||||||
f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
|
f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
|
Loading…
Reference in New Issue
Block a user