fixing reference model loading and freezing

This commit is contained in:
Goekdeniz-Guelmez 2025-01-19 00:41:27 +01:00
parent 1ff788821c
commit 582f979dfd
2 changed files with 3 additions and 3 deletions

View File

@@ -249,11 +249,11 @@ def train_model(
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
-reference_model = model
+reference_model, _ = load(args.model)
train_dpo(
model=model,
-reference_model=reference_model,
+reference_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,

View File

@@ -148,7 +148,7 @@ def dpo_loss(
logits = model(inputs)
logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
num_chosen_tokens = chosen_masks.sum(-1)