diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 53a32b91..63671ec4 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -316,7 +316,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set test_loss, _, _, test_metrics = evaluate_dpo( model=model, - ref_model=reference_model, + ref_model=reference_model.freeze(), dataset=test_set, batch_size=args.batch_size, num_batches=args.test_batches,