mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix batch size
This commit is contained in:
parent
e88f0fad4b
commit
f1961f1b79
@ -171,7 +171,7 @@ def generate_grpo(
|
||||
break
|
||||
|
||||
current_tokens.append(token)
|
||||
|
||||
|
||||
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
||||
mx.array(current_tokens[-len(end_sequence):]), end_sequence
|
||||
):
|
||||
@ -623,7 +623,8 @@ def train_grpo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
max_tokens=args.max_completion_length,
|
||||
group_size=args.group_size,
|
||||
temperature=args.temperature
|
||||
temperature=args.temperature,
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
|
||||
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||
|
Loading…
Reference in New Issue
Block a user