mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
small fix
This commit is contained in:
@@ -273,7 +273,7 @@ def train_dpo(
|
||||
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
|
||||
return loss(
|
||||
model=model,
|
||||
- reference_teacher_model=ref_model,
|
||||
+ ref_model=ref_model,
|
||||
chosen=chosen,
|
||||
rejected=rejected,
|
||||
chosen_masks=chosen_masks,
|
||||
@@ -313,7 +313,7 @@ def train_dpo(
|
||||
stop = time.perf_counter()
|
||||
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
|
||||
model=model,
|
||||
- reference_model=ref_model,
|
||||
+ ref_model=ref_model,
|
||||
dataset=val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
|
||||
Reference in New Issue
Block a user