mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix the test
This commit is contained in:
parent
ece20f1d64
commit
c5e09a1725
@ -3,6 +3,7 @@
|
||||
import math
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
|
||||
from mlx_lm.tuner.utils import build_schedule
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swapped_with_identity(obj, func):
|
||||
old_func = getattr(obj, func)
|
||||
setattr(obj, func, lambda x: x)
|
||||
yield
|
||||
setattr(obj, func, old_func)
|
||||
|
||||
|
||||
class TestLora(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.capturedOutput = StringIO()
|
||||
@ -374,6 +383,7 @@ class TestScheduleConfig(unittest.TestCase):
|
||||
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
||||
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
||||
]
|
||||
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||
evaluate(
|
||||
model=mock_model,
|
||||
dataset=mock_dataset,
|
||||
@ -412,6 +422,7 @@ class TestScheduleConfig(unittest.TestCase):
|
||||
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
||||
]
|
||||
|
||||
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||
evaluate(
|
||||
model=mock_model,
|
||||
dataset=mock_dataset,
|
||||
|
Loading…
Reference in New Issue
Block a user