mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
small fix
This commit is contained in:
parent
a03d434bb9
commit
fbb51f651a
@ -271,7 +271,7 @@ def train_model(
|
|||||||
|
|
||||||
train_dpo(
|
train_dpo(
|
||||||
model=model,
|
model=model,
|
||||||
reference_model=reference_model.freeze(),
|
ref_model=reference_model.freeze(),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
optimizer=opt,
|
optimizer=opt,
|
||||||
train_dataset=train_set,
|
train_dataset=train_set,
|
||||||
@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
|
|||||||
|
|
||||||
test_loss, test_rewards = evaluate_dpo(
|
test_loss, test_rewards = evaluate_dpo(
|
||||||
model=model,
|
model=model,
|
||||||
reference_model=reference_model,
|
ref_model=reference_model,
|
||||||
dataset=test_set,
|
dataset=test_set,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
|
@ -273,7 +273,7 @@ def train_dpo(
|
|||||||
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
|
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
|
||||||
return loss(
|
return loss(
|
||||||
model=model,
|
model=model,
|
||||||
reference_teacher_model=ref_model,
|
ref_model=ref_model,
|
||||||
chosen=chosen,
|
chosen=chosen,
|
||||||
rejected=rejected,
|
rejected=rejected,
|
||||||
chosen_masks=chosen_masks,
|
chosen_masks=chosen_masks,
|
||||||
@ -313,7 +313,7 @@ def train_dpo(
|
|||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
|
val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
|
||||||
model=model,
|
model=model,
|
||||||
reference_model=ref_model,
|
ref_model=ref_model,
|
||||||
dataset=val_dataset,
|
dataset=val_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_batches=args.val_batches,
|
num_batches=args.val_batches,
|
||||||
|
Loading…
Reference in New Issue
Block a user