This commit is contained in:
Goekdeniz-Guelmez 2025-02-03 19:43:49 +01:00
parent 05d921b788
commit 40bca770ae

View File

@@ -271,9 +271,9 @@ def grpo_loss(
rewards = mx.zeros((len(all_completions),)) rewards = mx.zeros((len(all_completions),))
for reward_func in reward_funcs: for reward_func in reward_funcs:
func_rewards = mx.array(reward_func( func_rewards = mx.array(reward_func(
prompts=prompt_text, prompts=expanded_prompts,
completions=all_completion_texts, completions=all_completion_texts,
answer=answer_text answer=expanded_answers
)) ))
rewards += func_rewards rewards += func_rewards
@@ -310,9 +310,9 @@ def grpo_loss(
reward_metrics = {} reward_metrics = {}
for i, reward_func in enumerate(reward_funcs): for i, reward_func in enumerate(reward_funcs):
func_rewards = mx.array(reward_func( func_rewards = mx.array(reward_func(
prompts=prompt_text, prompts=expanded_prompts,
completions=all_completion_texts, completions=all_completion_texts,
answer=answer_text answer=expanded_answers
)) ))
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards) reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards) reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)