diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index ff3ff752..0775d23e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -287,7 +287,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
     model.eval()
 
     if args.training_mode == "orpo":
-        test_loss, test_rewards, _, _ = evaluate_orpo(
+        test_loss, test_rewards, _, test_metrics = evaluate_orpo(
             model=model,
             dataset=test_set,
             batch_size=args.batch_size,
@@ -297,6 +297,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         )
         test_ppl = math.exp(test_loss)
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+
+        print("ORPO Test Metrics:")
+        for metric_name, metric_value in test_metrics.items():
+            print(f" {metric_name}: {float(metric_value):.3f}")
     else:
         test_loss = evaluate(
             model=model,