From fbb51f651a2f9682e3929b8702a32edcbcb39b83 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sat, 1 Feb 2025 16:08:52 +0100 Subject: [PATCH] small fix --- llms/mlx_lm/lora.py | 4 ++-- llms/mlx_lm/tuner/dpo_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index b68c99ab..2791db2f 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -271,7 +271,7 @@ def train_model( train_dpo( model=model, - reference_model=reference_model.freeze(), + ref_model=reference_model.freeze(), tokenizer=tokenizer, optimizer=opt, train_dataset=train_set, @@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set test_loss, test_rewards = evaluate_dpo( model=model, - reference_model=reference_model, + ref_model=reference_model, dataset=test_set, tokenizer=tokenizer, batch_size=args.batch_size, diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py index 2f4a74b6..99272e81 100644 --- a/llms/mlx_lm/tuner/dpo_trainer.py +++ b/llms/mlx_lm/tuner/dpo_trainer.py @@ -273,7 +273,7 @@ def train_dpo( def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks): return loss( model=model, - reference_teacher_model=ref_model, + ref_model=ref_model, chosen=chosen, rejected=rejected, chosen_masks=chosen_masks, @@ -313,7 +313,7 @@ def train_dpo( stop = time.perf_counter() val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo( model=model, - reference_model=ref_model, + ref_model=ref_model, dataset=val_dataset, batch_size=args.batch_size, num_batches=args.val_batches,