Fix reference model loading and freezing

This commit is contained in:
Goekdeniz-Guelmez
2025-01-19 00:41:27 +01:00
parent 1ff788821c
commit 582f979dfd
2 changed files with 3 additions and 3 deletions

View File

@@ -148,7 +148,7 @@ def dpo_loss(
logits = model(inputs)
logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
num_chosen_tokens = chosen_masks.sum(-1)