From 56712664f6bfb5a56299938eadb356f73194558b Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 4 Feb 2025 11:21:52 +0100
Subject: [PATCH] nice metric printing in testing

---
 llms/mlx_lm/lora.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index ff3ff752..0775d23e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -287,7 +287,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
     model.eval()

     if args.training_mode == "orpo":
-        test_loss, test_rewards, _, _ = evaluate_orpo(
+        test_loss, test_rewards, _, test_metrics = evaluate_orpo(
             model=model,
             dataset=test_set,
             batch_size=args.batch_size,
@@ -297,6 +297,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         )
         test_ppl = math.exp(test_loss)
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+
+        print("ORPO Test Metrics:")
+        for metric_name, metric_value in test_metrics.items():
+            print(f"  {metric_name}: {float(metric_value):.3f}")
     else:
         test_loss = evaluate(
             model=model,
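
For context, below is a minimal self-contained sketch of the printing loop this patch adds, runnable outside the trainer. The exact contents of test_metrics are an assumption here: the patch only shows that evaluate_orpo returns a dict-like object whose values go through float() before formatting, which suggests they may be scalar mx.array values rather than plain Python floats. The metric names and values in the sketch are hypothetical, for illustration only.

    # Sketch of the metric-printing loop added by this patch.
    # Assumption: test_metrics maps metric names to scalars; values may be
    # scalar mx.array objects or plain floats -- float() normalizes both
    # before the ".3f" formatting is applied.
    import mlx.core as mx

    test_metrics = {  # hypothetical example values
        "chosen_logps": mx.array(-1.234),
        "rejected_logps": mx.array(-2.345),
        "accuracy": 0.875,
    }

    print("ORPO Test Metrics:")
    for metric_name, metric_value in test_metrics.items():
        print(f"  {metric_name}: {float(metric_value):.3f}")

Note the float() cast is what keeps the loop agnostic to the value type: formatting a scalar mx.array directly with ":.3f" would fail, while float() works on both MLX scalar arrays and native numbers.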