diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index b68c99ab..2791db2f 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -271,7 +271,7 @@ def train_model(
         train_dpo(
             model=model,
-            reference_model=reference_model.freeze(),
+            ref_model=reference_model.freeze(),
             tokenizer=tokenizer,
             optimizer=opt,
             train_dataset=train_set,
@@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         test_loss, test_rewards = evaluate_dpo(
             model=model,
-            reference_model=reference_model,
+            ref_model=reference_model,
             dataset=test_set,
             tokenizer=tokenizer,
             batch_size=args.batch_size,
diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index 2f4a74b6..99272e81 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -273,7 +273,7 @@ def train_dpo(
     def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
         return loss(
             model=model,
-            reference_teacher_model=ref_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -313,7 +313,7 @@ def train_dpo(
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                 model=model,
-                reference_model=ref_model,
+                ref_model=ref_model,
                 dataset=val_dataset,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
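
Note (not part of the diff): the change unifies the reference-model keyword to `ref_model` across the `train_dpo`, `evaluate_dpo`, and DPO loss call sites, replacing the mixed `reference_model` / `reference_teacher_model` spellings. A minimal, self-contained sketch of why the rename matters for callers; `train_dpo_stub` below is a hypothetical stand-in, not the repo's actual `train_dpo`, and only mirrors the keyword-argument shape shown in the hunks above.

```python
# Illustration only: a stub whose signature mirrors the renamed keyword.
def train_dpo_stub(*, model, ref_model, **kwargs):
    """Stand-in for the DPO training entry point after the rename."""
    return model, ref_model

# New-style call site (matches the '+' lines above): accepted.
train_dpo_stub(model="policy", ref_model="frozen-reference")

# Old-style call site (matches the '-' lines above): rejected,
# because reference_model is no longer a recognized keyword.
try:
    train_dpo_stub(model="policy", reference_model="frozen-reference")
except TypeError as exc:
    print(f"stale keyword rejected: {exc}")
```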