fix batch size

This commit is contained in:
Goekdeniz-Guelmez 2025-03-09 00:26:41 +01:00
parent e88f0fad4b
commit f1961f1b79

View File

@ -623,7 +623,8 @@ def train_grpo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
max_tokens=args.max_completion_length, max_tokens=args.max_completion_length,
group_size=args.group_size, group_size=args.group_size,
temperature=args.temperature temperature=args.temperature,
batch_size=args.batch_size
) )
(loss, toks, metrics), grad = loss_value_and_grad( (loss, toks, metrics), grad = loss_value_and_grad(