mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 03:05:20 +08:00
small fix
This commit is contained in:
parent
ef6ff92add
commit
fab2dc2688
@@ -279,7 +279,8 @@ def grpo_loss(
|
||||
reward_weights = mx.array(reward_weights, dtype=mx.float32)
|
||||
else:
|
||||
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
|
||||
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
||||
|
||||
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
||||
|
||||
# Reshape rewards and compute advantages
|
||||
rewards_reshaped = rewards.reshape(batch_size, group_size)
|
||||
|
Loading…
Reference in New Issue
Block a user