From f30413b63c37de73751ac93b0dec305d26b87b20 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Sat, 4 May 2024 23:52:42 +1000
Subject: [PATCH] chore(mlx-lm): fix the number of validation batches
 configuration. (#752)

* chore: fix number of validation batches
* clean up
* address comment
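
Passing num_batches=-1 now means "evaluate on the entire validation
set": instead of zipping the batch stream against range(num_batches),
evaluate() zips it against an endless dummy iterator, so iteration
stops only when iterate_batches is exhausted. A minimal standalone
sketch of the idiom (illustrative only; `batches` here is a stand-in
for the real iterate_batches output):

```python
num_batches = -1  # -1 means "use every available batch"
batches = ["batch0", "batch1", "batch2"]  # stand-in for iterate_batches(...)

# Two-argument iter(callable, sentinel) calls int() repeatedly and stops
# only when the result equals the sentinel; int() returns 0, never 1,
# so this yields an infinite stream of dummy indices.
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

# zip() stops at the shorter iterable, so with num_batches == -1 the
# batch stream decides when evaluation ends.
for _, batch in zip(index_iterator, batches):
    print(batch)  # consumes all three batches
```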
---
 llms/mlx_lm/tuner/trainer.py |  7 +++-
 llms/tests/test_lora.py      | 80 ++++++++++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 1408cd8f..e2b55db3 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -137,8 +137,11 @@ def evaluate(
 ):
     all_losses = []
     ntokens = 0
-    for it, batch in zip(
-        range(num_batches),
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+    for _, batch in zip(
+        index_iterator,
         iterate_batches(
             dataset=dataset,
             tokenizer=tokenizer,
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
index 61afedf4..5918c634 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_lora.py
@@ -11,6 +11,7 @@ import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
 from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
 
 
@@ -169,6 +170,85 @@ class TestScheduleConfig(unittest.TestCase):
config = {"cosine_decay": None}
self.assertRaises(KeyError, build_schedule, config)
+ def test_evaluate_calls(self):
+ mock_model = MagicMock()
+ mock_dataset = MagicMock()
+ mock_tokenizer = MagicMock()
+ mock_default_loss = MagicMock()
+ mock_iterate_batches = MagicMock()
+
+ mock_iterate_batches.return_value = [
+ (MagicMock(), MagicMock()),
+ (MagicMock(), MagicMock()),
+ (MagicMock(), MagicMock()),
+ (MagicMock(), MagicMock()),
+ (MagicMock(), MagicMock()),
+ ]
+
+ mock_default_loss.side_effect = [
+ (MagicMock(return_value=0.5), MagicMock(return_value=100)),
+ (MagicMock(return_value=0.3), MagicMock(return_value=200)),
+ (MagicMock(return_value=0.2), MagicMock(return_value=150)),
+ (MagicMock(return_value=0.4), MagicMock(return_value=180)),
+ (MagicMock(return_value=0.6), MagicMock(return_value=120)),
+ ]
+ evaluate(
+ model=mock_model,
+ dataset=mock_dataset,
+ tokenizer=mock_tokenizer,
+ batch_size=2,
+ num_batches=2,
+ max_seq_length=2048,
+ loss=mock_default_loss,
+ iterate_batches=mock_iterate_batches,
+ )
+
+ mock_iterate_batches.assert_called_once_with(
+ dataset=mock_dataset,
+ tokenizer=mock_tokenizer,
+ batch_size=2,
+ max_seq_length=2048,
+ )
+ self.assertEqual(mock_default_loss.call_count, 2)
+
+ def test_evaluate_infinite_batches(self):
+ mock_model = MagicMock()
+ mock_dataset = MagicMock()
+ mock_tokenizer = MagicMock()
+ mock_default_loss = MagicMock()
+ mock_iterate_batches = MagicMock()
+
+ mock_iterate_batches.return_value = [
+ (MagicMock(), MagicMock()),
+ (MagicMock(), MagicMock()),
+ (MagicMock(), MagicMock()),
+ ]
+
+ mock_default_loss.side_effect = [
+ (MagicMock(return_value=0.5), MagicMock(return_value=100)),
+ (MagicMock(return_value=0.3), MagicMock(return_value=200)),
+ (MagicMock(return_value=0.2), MagicMock(return_value=150)),
+ ]
+
+ evaluate(
+ model=mock_model,
+ dataset=mock_dataset,
+ tokenizer=mock_tokenizer,
+ batch_size=2,
+ num_batches=-1,
+ max_seq_length=2048,
+ loss=mock_default_loss,
+ iterate_batches=mock_iterate_batches,
+ )
+
+ mock_iterate_batches.assert_called_once_with(
+ dataset=mock_dataset,
+ tokenizer=mock_tokenizer,
+ batch_size=2,
+ max_seq_length=2048,
+ )
+ self.assertEqual(mock_default_loss.call_count, 3)
+
if __name__ == "__main__":
unittest.main()