Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00.
chore(mlx-lm): fix the number of validation batches configuration. (#752)
* chore: fix number of validation batches
* clean up
* address review comment
This commit is contained in:
parent
2bf11c4633
commit
f30413b63c
@ -137,8 +137,11 @@ def evaluate(
|
||||
):
|
||||
all_losses = []
|
||||
ntokens = 0
|
||||
for it, batch in zip(
|
||||
range(num_batches),
|
||||
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
|
||||
for _, batch in zip(
|
||||
index_iterator,
|
||||
iterate_batches(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -11,6 +11,7 @@ import mlx.optimizers as opt
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx_lm import lora, tuner
|
||||
from mlx_lm.tuner.lora import LoRALinear
|
||||
from mlx_lm.tuner.trainer import evaluate
|
||||
from mlx_lm.tuner.utils import build_schedule
|
||||
|
||||
|
||||
@ -169,6 +170,85 @@ class TestScheduleConfig(unittest.TestCase):
|
||||
config = {"cosine_decay": None}
|
||||
self.assertRaises(KeyError, build_schedule, config)
|
||||
|
||||
def test_evaluate_calls(self):
    """evaluate() should draw exactly `num_batches` batches and invoke the
    loss once per batch, even when the iterator could yield more."""
    model = MagicMock()
    dataset = MagicMock()
    tokenizer = MagicMock()
    loss_fn = MagicMock()
    batch_iter = MagicMock()

    # Five batches are available, but num_batches=2 must cap consumption.
    batch_iter.return_value = [(MagicMock(), MagicMock()) for _ in range(5)]

    loss_fn.side_effect = [
        (MagicMock(return_value=loss), MagicMock(return_value=ntoks))
        for loss, ntoks in [(0.5, 100), (0.3, 200), (0.2, 150), (0.4, 180), (0.6, 120)]
    ]
    evaluate(
        model=model,
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        num_batches=2,
        max_seq_length=2048,
        loss=loss_fn,
        iterate_batches=batch_iter,
    )

    # The batch source is constructed once with the exact keyword arguments.
    batch_iter.assert_called_once_with(
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        max_seq_length=2048,
    )
    # Only num_batches (2) of the 5 available batches were evaluated.
    self.assertEqual(loss_fn.call_count, 2)
|
||||
|
||||
def test_evaluate_infinite_batches(self):
    """With num_batches=-1, evaluate() should exhaust the batch iterator,
    calling the loss once for every batch it yields."""
    model = MagicMock()
    dataset = MagicMock()
    tokenizer = MagicMock()
    loss_fn = MagicMock()
    batch_iter = MagicMock()

    # A finite source of three batches stands in for "iterate until exhausted".
    batch_iter.return_value = [(MagicMock(), MagicMock()) for _ in range(3)]

    loss_fn.side_effect = [
        (MagicMock(return_value=loss), MagicMock(return_value=ntoks))
        for loss, ntoks in [(0.5, 100), (0.3, 200), (0.2, 150)]
    ]

    evaluate(
        model=model,
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        num_batches=-1,
        max_seq_length=2048,
        loss=loss_fn,
        iterate_batches=batch_iter,
    )

    # The batch source is constructed once with the exact keyword arguments.
    batch_iter.assert_called_once_with(
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        max_seq_length=2048,
    )
    # All three yielded batches were consumed despite num_batches=-1.
    self.assertEqual(loss_fn.call_count, 3)
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user