mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix the test
This commit is contained in:
parent
ece20f1d64
commit
c5e09a1725
@ -3,6 +3,7 @@
|
|||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
from contextlib import contextmanager
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
|
|||||||
from mlx_lm.tuner.utils import build_schedule
|
from mlx_lm.tuner.utils import build_schedule
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def swapped_with_identity(obj, func):
|
||||||
|
old_func = getattr(obj, func)
|
||||||
|
setattr(obj, func, lambda x: x)
|
||||||
|
yield
|
||||||
|
setattr(obj, func, old_func)
|
||||||
|
|
||||||
|
|
||||||
class TestLora(unittest.TestCase):
|
class TestLora(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.capturedOutput = StringIO()
|
self.capturedOutput = StringIO()
|
||||||
@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
||||||
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
||||||
]
|
]
|
||||||
evaluate(
|
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||||
model=mock_model,
|
evaluate(
|
||||||
dataset=mock_dataset,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
dataset=mock_dataset,
|
||||||
batch_size=2,
|
tokenizer=mock_tokenizer,
|
||||||
num_batches=2,
|
batch_size=2,
|
||||||
max_seq_length=2048,
|
num_batches=2,
|
||||||
loss=mock_default_loss,
|
max_seq_length=2048,
|
||||||
iterate_batches=mock_iterate_batches,
|
loss=mock_default_loss,
|
||||||
)
|
iterate_batches=mock_iterate_batches,
|
||||||
|
)
|
||||||
|
|
||||||
mock_iterate_batches.assert_called_once_with(
|
mock_iterate_batches.assert_called_once_with(
|
||||||
dataset=mock_dataset,
|
dataset=mock_dataset,
|
||||||
@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
||||||
]
|
]
|
||||||
|
|
||||||
evaluate(
|
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||||
model=mock_model,
|
evaluate(
|
||||||
dataset=mock_dataset,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
dataset=mock_dataset,
|
||||||
batch_size=2,
|
tokenizer=mock_tokenizer,
|
||||||
num_batches=-1,
|
batch_size=2,
|
||||||
max_seq_length=2048,
|
num_batches=-1,
|
||||||
loss=mock_default_loss,
|
max_seq_length=2048,
|
||||||
iterate_batches=mock_iterate_batches,
|
loss=mock_default_loss,
|
||||||
)
|
iterate_batches=mock_iterate_batches,
|
||||||
|
)
|
||||||
|
|
||||||
mock_iterate_batches.assert_called_once_with(
|
mock_iterate_batches.assert_called_once_with(
|
||||||
dataset=mock_dataset,
|
dataset=mock_dataset,
|
||||||
|
Loading…
Reference in New Issue
Block a user