mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-19 09:31:13 +08:00
print func name
This commit is contained in:
parent
40bca770ae
commit
06f9c29c94
@ -309,13 +309,14 @@ def grpo_loss(
|
||||
# Collect reward metrics
|
||||
reward_metrics = {}
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
func_name = reward_func.__name__
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=expanded_prompts,
|
||||
completions=all_completion_texts,
|
||||
answer=expanded_answers
|
||||
))
|
||||
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
|
||||
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
|
||||
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
|
||||
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
|
||||
|
||||
# Clean up
|
||||
del all_completions
|
||||
|
Loading…
Reference in New Issue
Block a user