Fix reference model loading and freezing

This commit is contained in:
Goekdeniz-Guelmez 2025-01-19 00:41:27 +01:00
parent 1ff788821c
commit 582f979dfd
2 changed files with 3 additions and 3 deletions

View File

@@ -249,11 +249,11 @@ def train_model(
if args.reference_model_path: if args.reference_model_path:
reference_model, _ = load(args.reference_model_path) reference_model, _ = load(args.reference_model_path)
else: else:
reference_model = model reference_model, _ = load(args.model)
train_dpo( train_dpo(
model=model, model=model,
reference_model=reference_model, reference_model=reference_model.freeze(),
tokenizer=tokenizer, tokenizer=tokenizer,
optimizer=opt, optimizer=opt,
train_dataset=train_set, train_dataset=train_set,

View File

@@ -148,7 +148,7 @@ def dpo_loss(
logits = model(inputs) logits = model(inputs)
logits = logits.astype(mx.float32) logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1] return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
num_chosen_tokens = chosen_masks.sum(-1) num_chosen_tokens = chosen_masks.sum(-1)