diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 68fa93da..1bbd7b63 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -295,13 +295,12 @@ def train_model( if args.reference_model_path: reference_model, _ = load(args.reference_model_path) - reference_model = reference_model.freeze() else: reference_model, _ = load(args.model) - + train_grpo( model=model, - ref_model=reference_model, + ref_model=reference_model.freeze(), tokenizer=tokenizer, optimizer=opt, train_dataset=train_set, @@ -340,11 +339,11 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set if args.reference_model_path: reference_model, _ = load(args.reference_model_path) else: - reference_model = model + reference_model, _ = load(args.model) test_loss, _, test_rewards = evaluate_grpo( model=model, - ref_model=reference_model, + ref_model=reference_model.freeze(), dataset=test_set, tokenizer=tokenizer, batch_size=args.batch_size,