fixing reference model loading and freezing

This commit is contained in:
Goekdeniz-Guelmez 2025-01-19 00:41:27 +01:00
parent 1ff788821c
commit 582f979dfd
2 changed files with 3 additions and 3 deletions

View File

@@ -249,11 +249,11 @@ def train_model(
if args.reference_model_path: if args.reference_model_path:
reference_model, _ = load(args.reference_model_path) reference_model, _ = load(args.reference_model_path)
else: else:
reference_model = model reference_model, _ = load(args.model)
train_dpo( train_dpo(
model=model, model=model,
reference_model=reference_model, reference_model=reference_model.freeze(),
tokenizer=tokenizer, tokenizer=tokenizer,
optimizer=opt, optimizer=opt,
train_dataset=train_set, train_dataset=train_set,