mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 19:31:20 +08:00
smoll fix
This commit is contained in:
parent
ef6ff92add
commit
fab2dc2688
@ -279,7 +279,8 @@ def grpo_loss(
|
|||||||
reward_weights = mx.array(reward_weights, dtype=mx.float32)
|
reward_weights = mx.array(reward_weights, dtype=mx.float32)
|
||||||
else:
|
else:
|
||||||
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
|
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
|
||||||
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
|
||||||
|
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
||||||
|
|
||||||
# Reshape rewards and compute advantages
|
# Reshape rewards and compute advantages
|
||||||
rewards_reshaped = rewards.reshape(batch_size, group_size)
|
rewards_reshaped = rewards.reshape(batch_size, group_size)
|
||||||
|
Loading…
Reference in New Issue
Block a user