diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index fd0dffdc..8cceca21 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -325,7 +325,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set delta=args.delta, loss_type=args.dpo_loss_type, ) - print(f"Test loss {test_loss:.3f}") + + test_ppl = math.exp(test_loss) + + print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") else: test_loss = evaluate(