mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix batch size
This commit is contained in:
parent
e88f0fad4b
commit
f1961f1b79
@ -171,7 +171,7 @@ def generate_grpo(
|
|||||||
break
|
break
|
||||||
|
|
||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
|
|
||||||
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
||||||
mx.array(current_tokens[-len(end_sequence):]), end_sequence
|
mx.array(current_tokens[-len(end_sequence):]), end_sequence
|
||||||
):
|
):
|
||||||
@ -623,7 +623,8 @@ def train_grpo(
|
|||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
max_tokens=args.max_completion_length,
|
max_tokens=args.max_completion_length,
|
||||||
group_size=args.group_size,
|
group_size=args.group_size,
|
||||||
temperature=args.temperature
|
temperature=args.temperature,
|
||||||
|
batch_size=args.batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
(loss, toks, metrics), grad = loss_value_and_grad(
|
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||||
|
Loading…
Reference in New Issue
Block a user