adding test_ppl in testing

This commit is contained in:
Goekdeniz-Guelmez 2025-02-04 11:18:09 +01:00
parent 43f2451973
commit 069431bd65

View File

@ -325,7 +325,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
delta=args.delta,
loss_type=args.dpo_loss_type,
)
print(f"Test loss {test_loss:.3f}")
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
else:
test_loss = evaluate(