mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
small fix
This commit is contained in:
parent
a03d434bb9
commit
fbb51f651a
@ -271,7 +271,7 @@ def train_model(
|
||||
|
||||
train_dpo(
|
||||
model=model,
|
||||
reference_model=reference_model.freeze(),
|
||||
ref_model=reference_model.freeze(),
|
||||
tokenizer=tokenizer,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
|
||||
|
||||
test_loss, test_rewards = evaluate_dpo(
|
||||
model=model,
|
||||
reference_model=reference_model,
|
||||
ref_model=reference_model,
|
||||
dataset=test_set,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
|
@ -273,7 +273,7 @@ def train_dpo(
|
||||
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
|
||||
return loss(
|
||||
model=model,
|
||||
reference_teacher_model=ref_model,
|
||||
ref_model=ref_model,
|
||||
chosen=chosen,
|
||||
rejected=rejected,
|
||||
chosen_masks=chosen_masks,
|
||||
@ -313,7 +313,7 @@ def train_dpo(
|
||||
stop = time.perf_counter()
|
||||
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
|
||||
model=model,
|
||||
reference_model=ref_model,
|
||||
ref_model=ref_model,
|
||||
dataset=val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
|
Loading…
Reference in New Issue
Block a user