fix testing

This commit is contained in:
Goekdeniz-Guelmez 2025-02-05 08:53:30 +01:00
parent 2a8e6f6e44
commit d84ad0cf86

View File

@@ -342,7 +342,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
else: else:
reference_model = model reference_model = model
test_loss, test_rewards = evaluate_grpo( test_loss, _, test_rewards = evaluate_grpo(
model=model, model=model,
ref_model=reference_model, ref_model=reference_model,
dataset=test_set, dataset=test_set,
@@ -354,7 +354,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
group_size=args.group_size, group_size=args.group_size,
epsilon=args.epsilon epsilon=args.epsilon
) )
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else: else:
test_loss = evaluate( test_loss = evaluate(
model=model, model=model,