From fab2dc26888ce30bca3bf99f2e98fc256304712d Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 26 Feb 2025 15:21:57 +0100 Subject: [PATCH] smoll fix --- llms/mlx_lm/tuner/grpo_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index d41bedce..66abf99f 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -279,7 +279,8 @@ def grpo_loss( reward_weights = mx.array(reward_weights, dtype=mx.float32) else: reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32) - rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1) + + rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1) # Reshape rewards and compute advantages rewards_reshaped = rewards.reshape(batch_size, group_size)