smoll fix

This commit is contained in:
Goekdeniz-Guelmez 2025-02-26 15:21:57 +01:00
parent ef6ff92add
commit fab2dc2688

View File

@ -279,7 +279,8 @@ def grpo_loss(
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
# Reshape rewards and compute advantages
rewards_reshaped = rewards.reshape(batch_size, group_size)