From 541f0be9376642b604b5a1f6b7a179567a9f445e Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 17 Feb 2025 14:39:38 +0100 Subject: [PATCH] fix generation cutoff in evaluation --- llms/mlx_lm/tuner/grpo_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index e75da0fd..ea59ed06 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -403,7 +403,8 @@ def evaluate_grpo( beta: float, epsilon: float, group_size: int, - max_seq_length, + max_seq_length: int, + max_tokens: int, temperature: float, reward_funcs: Optional[List[RewardFunctions]] = None, loss_fn: callable = grpo_loss, @@ -432,7 +433,8 @@ def evaluate_grpo( group_size=group_size, epsilon=epsilon, ref_model=ref_model, - temperature=temperature + temperature=temperature, + max_tokens=max_tokens ) all_losses += losses * toks @@ -548,6 +550,7 @@ def train_grpo( batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, + max_tokens=args.max_completion_length, beta=args.beta, epsilon=args.epsilon, temperature=args.temperature,