Fix testing: unpack three return values from evaluate_grpo and print test perplexity

This commit is contained in:
Goekdeniz-Guelmez 2025-02-05 08:53:30 +01:00
parent 2a8e6f6e44
commit d84ad0cf86

View File

@@ -342,7 +342,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
else:
reference_model = model
-test_loss, test_rewards = evaluate_grpo(
+test_loss, _, test_rewards = evaluate_grpo(
model=model,
ref_model=reference_model,
dataset=test_set,
@@ -354,7 +354,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
group_size=args.group_size,
epsilon=args.epsilon
)
-print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+test_ppl = math.exp(test_loss)
+print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
test_loss = evaluate(
model=model,