From 39e94690592cbc407cfe9940aa3c6503994aa968 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sun, 9 Feb 2025 15:30:51 +0100 Subject: [PATCH] freeze ref model --- llms/mlx_lm/lora.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 68fa93da..1bbd7b63 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -295,13 +295,12 @@ def train_model( if args.reference_model_path: reference_model, _ = load(args.reference_model_path) - reference_model = reference_model.freeze() else: reference_model, _ = load(args.model) - + train_grpo( model=model, - ref_model=reference_model, + ref_model=reference_model.freeze(), tokenizer=tokenizer, optimizer=opt, train_dataset=train_set, @@ -340,11 +339,11 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set if args.reference_model_path: reference_model, _ = load(args.reference_model_path) else: - reference_model = model + reference_model, _ = load(args.model) test_loss, _, test_rewards = evaluate_grpo( model=model, - ref_model=reference_model, + ref_model=reference_model.freeze(), dataset=test_set, tokenizer=tokenizer, batch_size=args.batch_size,