diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 1408cd8f..e2b55db3 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -137,8 +137,11 @@ def evaluate(
 ):
     all_losses = []
     ntokens = 0
-    for it, batch in zip(
-        range(num_batches),
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+    for _, batch in zip(
+        index_iterator,
         iterate_batches(
             dataset=dataset,
             tokenizer=tokenizer,
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
index 61afedf4..5918c634 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_lora.py
@@ -11,6 +11,7 @@ import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
 from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
 
 
@@ -169,6 +170,85 @@ class TestScheduleConfig(unittest.TestCase):
         config = {"cosine_decay": None}
         self.assertRaises(KeyError, build_schedule, config)
 
+    def test_evaluate_calls(self):
+        mock_model = MagicMock()
+        mock_dataset = MagicMock()
+        mock_tokenizer = MagicMock()
+        mock_default_loss = MagicMock()
+        mock_iterate_batches = MagicMock()
+
+        mock_iterate_batches.return_value = [
+            (MagicMock(), MagicMock()),
+            (MagicMock(), MagicMock()),
+            (MagicMock(), MagicMock()),
+            (MagicMock(), MagicMock()),
+            (MagicMock(), MagicMock()),
+        ]
+
+        mock_default_loss.side_effect = [
+            (MagicMock(return_value=0.5), MagicMock(return_value=100)),
+            (MagicMock(return_value=0.3), MagicMock(return_value=200)),
+            (MagicMock(return_value=0.2), MagicMock(return_value=150)),
+            (MagicMock(return_value=0.4), MagicMock(return_value=180)),
+            (MagicMock(return_value=0.6), MagicMock(return_value=120)),
+        ]
+        evaluate(
+            model=mock_model,
+            dataset=mock_dataset,
+            tokenizer=mock_tokenizer,
+            batch_size=2,
+            num_batches=2,
+            max_seq_length=2048,
+            loss=mock_default_loss,
+            iterate_batches=mock_iterate_batches,
+        )
+
+        mock_iterate_batches.assert_called_once_with(
+            dataset=mock_dataset,
+            tokenizer=mock_tokenizer,
+            batch_size=2,
+            max_seq_length=2048,
+        )
+        self.assertEqual(mock_default_loss.call_count, 2)
+
+    def test_evaluate_infinite_batches(self):
+        mock_model = MagicMock()
+        mock_dataset = MagicMock()
+        mock_tokenizer = MagicMock()
+        mock_default_loss = MagicMock()
+        mock_iterate_batches = MagicMock()
+
+        mock_iterate_batches.return_value = [
+            (MagicMock(), MagicMock()),
+            (MagicMock(), MagicMock()),
+            (MagicMock(), MagicMock()),
+        ]
+
+        mock_default_loss.side_effect = [
+            (MagicMock(return_value=0.5), MagicMock(return_value=100)),
+            (MagicMock(return_value=0.3), MagicMock(return_value=200)),
+            (MagicMock(return_value=0.2), MagicMock(return_value=150)),
+        ]
+
+        evaluate(
+            model=mock_model,
+            dataset=mock_dataset,
+            tokenizer=mock_tokenizer,
+            batch_size=2,
+            num_batches=-1,
+            max_seq_length=2048,
+            loss=mock_default_loss,
+            iterate_batches=mock_iterate_batches,
+        )
+
+        mock_iterate_batches.assert_called_once_with(
+            dataset=mock_dataset,
+            tokenizer=mock_tokenizer,
+            batch_size=2,
+            max_seq_length=2048,
+        )
+        self.assertEqual(mock_default_loss.call_count, 3)
+
 
 if __name__ == "__main__":
     unittest.main()
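
A note on the `iter(int, 1)` idiom in the trainer change: the two-argument form `iter(callable, sentinel)` calls the callable repeatedly and stops only when the result equals the sentinel. Since `int()` always returns `0` and never `1`, this yields an infinite iterator, and `zip` then terminates when `iterate_batches(...)` is exhausted, which is how `num_batches=-1` ends up walking the full dataset. A minimal standalone sketch of the idiom (not part of the patch):

```python
from itertools import islice

# iter(callable, sentinel) would stop once int() == 1;
# int() always returns 0, so this iterator never terminates.
infinite = iter(int, 1)
print(list(islice(infinite, 4)))  # [0, 0, 0, 0]

# zip() stops at the shorter iterable, so pairing the infinite
# index iterator with a finite batch stream consumes every batch,
# mirroring the num_batches == -1 path in evaluate().
batches = ["batch0", "batch1", "batch2"]
for _, batch in zip(iter(int, 1), batches):
    print(batch)  # batch0, batch1, batch2
```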