diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index e75da0fd..ea59ed06 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -403,7 +403,8 @@ def evaluate_grpo( beta: float, epsilon: float, group_size: int, - max_seq_length, + max_seq_length: int, + max_tokens: int, temperature: float, reward_funcs: Optional[List[RewardFunctions]] = None, loss_fn: callable = grpo_loss, @@ -432,7 +433,8 @@ def evaluate_grpo( group_size=group_size, epsilon=epsilon, ref_model=ref_model, - temperature=temperature + temperature=temperature, + max_tokens=max_tokens ) all_losses += losses * toks @@ -548,6 +550,7 @@ def train_grpo( batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, + max_tokens=args.max_completion_length, beta=args.beta, epsilon=args.epsilon, temperature=args.temperature,