fixing reference model loading and freezing

This commit is contained in:
Goekdeniz-Guelmez 2025-01-19 00:41:27 +01:00
parent 1ff788821c
commit 582f979dfd
2 changed files with 3 additions and 3 deletions

View File

@ -249,11 +249,11 @@ def train_model(
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model = model
reference_model, _ = load(args.model)
train_dpo(
model=model,
reference_model=reference_model,
reference_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,