chore(mlx-lm): fix the number of validation batches configuration. (#752)

* chore: fix number of validation batches

* clean up

* address comment
This commit is contained in:
Anchen 2024-05-04 23:52:42 +10:00 committed by GitHub
parent 2bf11c4633
commit f30413b63c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 2 deletions

View File

@ -137,8 +137,11 @@ def evaluate(
):
all_losses = []
ntokens = 0
for it, batch in zip(
range(num_batches),
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_batches(
dataset=dataset,
tokenizer=tokenizer,

View File

@ -11,6 +11,7 @@ import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@ -169,6 +170,85 @@ class TestScheduleConfig(unittest.TestCase):
config = {"cosine_decay": None}
self.assertRaises(KeyError, build_schedule, config)
def test_evaluate_calls(self):
mock_model = MagicMock()
mock_dataset = MagicMock()
mock_tokenizer = MagicMock()
mock_default_loss = MagicMock()
mock_iterate_batches = MagicMock()
mock_iterate_batches.return_value = [
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
]
mock_default_loss.side_effect = [
(MagicMock(return_value=0.5), MagicMock(return_value=100)),
(MagicMock(return_value=0.3), MagicMock(return_value=200)),
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
]
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=2,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
mock_iterate_batches.assert_called_once_with(
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
max_seq_length=2048,
)
self.assertEqual(mock_default_loss.call_count, 2)
def test_evaluate_infinite_batches(self):
mock_model = MagicMock()
mock_dataset = MagicMock()
mock_tokenizer = MagicMock()
mock_default_loss = MagicMock()
mock_iterate_batches = MagicMock()
mock_iterate_batches.return_value = [
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
]
mock_default_loss.side_effect = [
(MagicMock(return_value=0.5), MagicMock(return_value=100)),
(MagicMock(return_value=0.3), MagicMock(return_value=200)),
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
]
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=-1,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
mock_iterate_batches.assert_called_once_with(
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
max_seq_length=2048,
)
self.assertEqual(mock_default_loss.call_count, 3)
if __name__ == "__main__":
unittest.main()