mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-04 15:56:16 +08:00
fix generation cutoff in evaluation
This commit is contained in:
parent
1eea135a20
commit
541f0be937
@ -403,7 +403,8 @@ def evaluate_grpo(
|
|||||||
beta: float,
|
beta: float,
|
||||||
epsilon: float,
|
epsilon: float,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
max_seq_length,
|
max_seq_length: int,
|
||||||
|
max_tokens: int,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
reward_funcs: Optional[List[RewardFunctions]] = None,
|
reward_funcs: Optional[List[RewardFunctions]] = None,
|
||||||
loss_fn: callable = grpo_loss,
|
loss_fn: callable = grpo_loss,
|
||||||
@ -432,7 +433,8 @@ def evaluate_grpo(
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
epsilon=epsilon,
|
epsilon=epsilon,
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
temperature=temperature
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
all_losses += losses * toks
|
all_losses += losses * toks
|
||||||
@ -548,6 +550,7 @@ def train_grpo(
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_batches=args.val_batches,
|
num_batches=args.val_batches,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
|
max_tokens=args.max_completion_length,
|
||||||
beta=args.beta,
|
beta=args.beta,
|
||||||
epsilon=args.epsilon,
|
epsilon=args.epsilon,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
|
Loading…
Reference in New Issue
Block a user