fix generation cutoff in evaluation by passing max_tokens through evaluate_grpo

This commit is contained in:
Goekdeniz-Guelmez 2025-02-17 14:39:38 +01:00
parent 1eea135a20
commit 541f0be937

View File

@@ -403,7 +403,8 @@ def evaluate_grpo(
beta: float,
epsilon: float,
group_size: int,
max_seq_length,
max_seq_length: int,
max_tokens: int,
temperature: float,
reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
@@ -432,7 +433,8 @@ def evaluate_grpo(
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model,
temperature=temperature
temperature=temperature,
max_tokens=max_tokens
)
all_losses += losses * toks
@@ -548,6 +550,7 @@ def train_grpo(
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
max_tokens=args.max_completion_length,
beta=args.beta,
epsilon=args.epsilon,
temperature=args.temperature,