mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 17:58:54 +08:00
smoll fix
This commit is contained in:
@@ -279,6 +279,7 @@ def grpo_loss(
|
||||
reward_weights = mx.array(reward_weights, dtype=mx.float32)
|
||||
else:
|
||||
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
|
||||
|
||||
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
||||
|
||||
# Reshape rewards and compute advantages
|
||||
|
||||
Reference in New Issue
Block a user