From b1c1e1353eff3119ce29f21fabbf00d6c92cc14f Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 4 Feb 2025 11:19:59 +0100 Subject: [PATCH] nice printing the test metrics --- llms/mlx_lm/lora.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 8cceca21..53a32b91 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -314,7 +314,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set else: reference_model = model - test_loss, _, _, _ = evaluate_dpo( + test_loss, _, _, test_metrics = evaluate_dpo( model=model, ref_model=reference_model, dataset=test_set, @@ -328,7 +328,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set test_ppl = math.exp(test_loss) - print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") + print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}") + print("DPO Test Metrics:") + for metric_name, metric_value in test_metrics.items(): + print(f" {metric_name}: {float(metric_value):.3f}") else: test_loss = evaluate(