mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-04 07:41:13 +08:00
fixes
This commit is contained in:
parent
05d921b788
commit
40bca770ae
@ -271,9 +271,9 @@ def grpo_loss(
|
|||||||
rewards = mx.zeros((len(all_completions),))
|
rewards = mx.zeros((len(all_completions),))
|
||||||
for reward_func in reward_funcs:
|
for reward_func in reward_funcs:
|
||||||
func_rewards = mx.array(reward_func(
|
func_rewards = mx.array(reward_func(
|
||||||
prompts=prompt_text,
|
prompts=expanded_prompts,
|
||||||
completions=all_completion_texts,
|
completions=all_completion_texts,
|
||||||
answer=answer_text
|
answer=expanded_answers
|
||||||
))
|
))
|
||||||
rewards += func_rewards
|
rewards += func_rewards
|
||||||
|
|
||||||
@ -310,9 +310,9 @@ def grpo_loss(
|
|||||||
reward_metrics = {}
|
reward_metrics = {}
|
||||||
for i, reward_func in enumerate(reward_funcs):
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
func_rewards = mx.array(reward_func(
|
func_rewards = mx.array(reward_func(
|
||||||
prompts=prompt_text,
|
prompts=expanded_prompts,
|
||||||
completions=all_completion_texts,
|
completions=all_completion_texts,
|
||||||
answer=answer_text
|
answer=expanded_answers
|
||||||
))
|
))
|
||||||
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
|
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
|
||||||
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
|
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
|
||||||
|
Loading…
Reference in New Issue
Block a user