fix batch size: pass `batch_size=args.batch_size` through to the `generate_grpo` call in `train_grpo`

This commit is contained in:
Goekdeniz-Guelmez 2025-03-09 00:26:41 +01:00
parent e88f0fad4b
commit f1961f1b79

View File

@ -171,7 +171,7 @@ def generate_grpo(
break
current_tokens.append(token)
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
mx.array(current_tokens[-len(end_sequence):]), end_sequence
):
@ -623,7 +623,8 @@ def train_grpo(
prompt_tokens=prompt_tokens,
max_tokens=args.max_completion_length,
group_size=args.group_size,
temperature=args.temperature
temperature=args.temperature,
batch_size=args.batch_size
)
(loss, toks, metrics), grad = loss_value_and_grad(