mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-02 06:41:13 +08:00
fix generation cutoff in evaluation
This commit is contained in:
parent
1eea135a20
commit
541f0be937
@ -403,7 +403,8 @@ def evaluate_grpo(
|
||||
beta: float,
|
||||
epsilon: float,
|
||||
group_size: int,
|
||||
max_seq_length,
|
||||
max_seq_length: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reward_funcs: Optional[List[RewardFunctions]] = None,
|
||||
loss_fn: callable = grpo_loss,
|
||||
@ -432,7 +433,8 @@ def evaluate_grpo(
|
||||
group_size=group_size,
|
||||
epsilon=epsilon,
|
||||
ref_model=ref_model,
|
||||
temperature=temperature
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
all_losses += losses * toks
|
||||
@ -548,6 +550,7 @@ def train_grpo(
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
max_tokens=args.max_completion_length,
|
||||
beta=args.beta,
|
||||
epsilon=args.epsilon,
|
||||
temperature=args.temperature,
|
||||
|
Loading…
Reference in New Issue
Block a user