small fix

This commit is contained in:
Goekdeniz-Guelmez 2025-02-01 16:08:52 +01:00
parent a03d434bb9
commit fbb51f651a
2 changed files with 4 additions and 4 deletions

View File

@@ -271,7 +271,7 @@ def train_model(
train_dpo( train_dpo(
model=model, model=model,
reference_model=reference_model.freeze(), ref_model=reference_model.freeze(),
tokenizer=tokenizer, tokenizer=tokenizer,
optimizer=opt, optimizer=opt,
train_dataset=train_set, train_dataset=train_set,
@@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
test_loss, test_rewards = evaluate_dpo( test_loss, test_rewards = evaluate_dpo(
model=model, model=model,
reference_model=reference_model, ref_model=reference_model,
dataset=test_set, dataset=test_set,
tokenizer=tokenizer, tokenizer=tokenizer,
batch_size=args.batch_size, batch_size=args.batch_size,

View File

@@ -273,7 +273,7 @@ def train_dpo(
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks): def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
return loss( return loss(
model=model, model=model,
reference_teacher_model=ref_model, ref_model=ref_model,
chosen=chosen, chosen=chosen,
rejected=rejected, rejected=rejected,
chosen_masks=chosen_masks, chosen_masks=chosen_masks,
@@ -313,7 +313,7 @@ def train_dpo(
stop = time.perf_counter() stop = time.perf_counter()
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo( val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
model=model, model=model,
reference_model=ref_model, ref_model=ref_model,
dataset=val_dataset, dataset=val_dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
num_batches=args.val_batches, num_batches=args.val_batches,