diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index f6dfc830..a8b3a1c9 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -271,12 +271,12 @@ def grpo_loss(
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
-            prompts=prompt_text,
+            prompts=expanded_prompts,
             completions=all_completion_texts,
-            answer=answer_text
+            answer=expanded_answers
         ))
         rewards += func_rewards
-
+
     if len(reward_funcs) > 1:
         rewards /= len(reward_funcs)

@@ -310,9 +310,9 @@ def grpo_loss(
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
         func_rewards = mx.array(reward_func(
-            prompts=prompt_text,
+            prompts=expanded_prompts,
             completions=all_completion_texts,
-            answer=answer_text
+            answer=expanded_answers
         ))
        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
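
Note: a minimal sketch of why the expanded lists matter, assuming GRPO samples a group of completions per prompt, so all_completion_texts holds len(prompts) * group_size entries. The helper name expand_for_completions and the group_size parameter are illustrative, not the actual trainer code; the point is that each prompt/answer must be repeated once per sampled completion so index i of every argument refers to the same example.

# Illustrative sketch only (assumed helper, not from grpo_trainer.py).
# Passing the unexpanded prompt_text/answer_text to a reward function
# misaligns prompts with completions once each prompt has multiple
# sampled completions.

def expand_for_completions(prompts, answers, group_size):
    # Repeat each prompt/answer group_size times, preserving order so
    # the i-th expanded entry lines up with the i-th completion.
    expanded_prompts = [p for p in prompts for _ in range(group_size)]
    expanded_answers = [a for a in answers for _ in range(group_size)]
    return expanded_prompts, expanded_answers

# Example: 2 prompts with group_size=3 yield 6 aligned entries.
ep, ea = expand_for_completions(["Q1", "Q2"], ["A1", "A2"], group_size=3)
assert ep == ["Q1", "Q1", "Q1", "Q2", "Q2", "Q2"]
assert ea == ["A1", "A1", "A1", "A2", "A2", "A2"]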