mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix the test
This commit is contained in:
parent
ece20f1d64
commit
c5e09a1725
@ -3,6 +3,7 @@
|
|||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
from contextlib import contextmanager
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
|
|||||||
from mlx_lm.tuner.utils import build_schedule
|
from mlx_lm.tuner.utils import build_schedule
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def swapped_with_identity(obj, func):
|
||||||
|
old_func = getattr(obj, func)
|
||||||
|
setattr(obj, func, lambda x: x)
|
||||||
|
yield
|
||||||
|
setattr(obj, func, old_func)
|
||||||
|
|
||||||
|
|
||||||
class TestLora(unittest.TestCase):
|
class TestLora(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.capturedOutput = StringIO()
|
self.capturedOutput = StringIO()
|
||||||
@ -374,6 +383,7 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
||||||
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
||||||
]
|
]
|
||||||
|
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||||
evaluate(
|
evaluate(
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
dataset=mock_dataset,
|
dataset=mock_dataset,
|
||||||
@ -412,6 +422,7 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||||
evaluate(
|
evaluate(
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
dataset=mock_dataset,
|
dataset=mock_dataset,
|
||||||
|
Loading…
Reference in New Issue
Block a user