mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
chore(mlx-lm): fix the number of validation batches configuration. (#752)
* chore: fix number of validation batches * clean up * address comment
This commit is contained in:
parent
2bf11c4633
commit
f30413b63c
@ -137,8 +137,11 @@ def evaluate(
|
|||||||
):
|
):
|
||||||
all_losses = []
|
all_losses = []
|
||||||
ntokens = 0
|
ntokens = 0
|
||||||
for it, batch in zip(
|
|
||||||
range(num_batches),
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
|
for _, batch in zip(
|
||||||
|
index_iterator,
|
||||||
iterate_batches(
|
iterate_batches(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
@ -11,6 +11,7 @@ import mlx.optimizers as opt
|
|||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from mlx_lm import lora, tuner
|
from mlx_lm import lora, tuner
|
||||||
from mlx_lm.tuner.lora import LoRALinear
|
from mlx_lm.tuner.lora import LoRALinear
|
||||||
|
from mlx_lm.tuner.trainer import evaluate
|
||||||
from mlx_lm.tuner.utils import build_schedule
|
from mlx_lm.tuner.utils import build_schedule
|
||||||
|
|
||||||
|
|
||||||
@ -169,6 +170,85 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
config = {"cosine_decay": None}
|
config = {"cosine_decay": None}
|
||||||
self.assertRaises(KeyError, build_schedule, config)
|
self.assertRaises(KeyError, build_schedule, config)
|
||||||
|
|
||||||
|
def test_evaluate_calls(self):
    """Check that ``evaluate`` honours a finite ``num_batches``.

    Five dummy batches and five loss results are made available while
    ``num_batches=2`` is requested. The test then verifies that the batch
    iterator factory was called exactly once with the expected keyword
    arguments and that the loss callable was invoked exactly twice.
    """
    model = MagicMock()
    dataset = MagicMock()
    tokenizer = MagicMock()

    # Stand-in for the batch iterator factory: every call hands back the
    # same list of five placeholder batches.
    batch_source = MagicMock()
    batch_source.return_value = [(MagicMock(), MagicMock()) for _ in range(5)]

    # One pair of mocks per potential loss call; only the first two should
    # ever be consumed. Presumably each pair is (loss, token count) --
    # confirm against the signature of the real loss function.
    loss_results = [(0.5, 100), (0.3, 200), (0.2, 150), (0.4, 180), (0.6, 120)]
    loss_fn = MagicMock()
    loss_fn.side_effect = [
        (MagicMock(return_value=first), MagicMock(return_value=second))
        for first, second in loss_results
    ]

    evaluate(
        model=model,
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        num_batches=2,
        max_seq_length=2048,
        loss=loss_fn,
        iterate_batches=batch_source,
    )

    batch_source.assert_called_once_with(
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        max_seq_length=2048,
    )
    self.assertEqual(loss_fn.call_count, 2)
|
||||||
|
|
||||||
|
def test_evaluate_infinite_batches(self):
    """Check that ``num_batches=-1`` makes ``evaluate`` use every batch.

    Three dummy batches and three loss results are made available. With
    ``num_batches=-1`` the test expects the batch iterator factory to be
    called exactly once and the loss callable to be invoked once per
    available batch, i.e. three times.
    """
    model = MagicMock()
    dataset = MagicMock()
    tokenizer = MagicMock()

    # Stand-in for the batch iterator factory: a finite list of three
    # placeholder batches, so the evaluation loop has a natural end.
    batch_source = MagicMock()
    batch_source.return_value = [(MagicMock(), MagicMock()) for _ in range(3)]

    # One pair of mocks per expected loss call. Presumably each pair is
    # (loss, token count) -- confirm against the real loss function.
    loss_results = [(0.5, 100), (0.3, 200), (0.2, 150)]
    loss_fn = MagicMock()
    loss_fn.side_effect = [
        (MagicMock(return_value=first), MagicMock(return_value=second))
        for first, second in loss_results
    ]

    evaluate(
        model=model,
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        num_batches=-1,
        max_seq_length=2048,
        loss=loss_fn,
        iterate_batches=batch_source,
    )

    batch_source.assert_called_once_with(
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=2,
        max_seq_length=2048,
    )
    self.assertEqual(loss_fn.call_count, 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user