nice printing the test metrics

This commit is contained in:
Goekdeniz-Guelmez 2025-02-04 11:19:59 +01:00
parent 069431bd65
commit b1c1e1353e

View File

@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
else:
reference_model = model
test_loss, _, _, _ = evaluate_dpo(
test_loss, _, _, test_metrics = evaluate_dpo(
model=model,
ref_model=reference_model,
dataset=test_set,
@ -328,7 +328,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}")
print("DPO Test Metrics:")
for metric_name, metric_value in test_metrics.items():
print(f" {metric_name}: {float(metric_value):.3f}")
else:
test_loss = evaluate(