small fix

This commit is contained in:
Goekdeniz-Guelmez
2025-02-01 16:08:52 +01:00
parent a03d434bb9
commit fbb51f651a
2 changed files with 4 additions and 4 deletions

View File

@@ -273,7 +273,7 @@ def train_dpo(
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
return loss(
model=model,
reference_teacher_model=ref_model,
ref_model=ref_model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
@@ -313,7 +313,7 @@ def train_dpo(
stop = time.perf_counter()
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
model=model,
reference_model=ref_model,
ref_model=ref_model,
dataset=val_dataset,
batch_size=args.batch_size,
num_batches=args.val_batches,