Fix the test

This commit is contained in:
Angelos Katharopoulos 2024-10-31 16:28:55 -07:00
parent ece20f1d64
commit c5e09a1725

View File

@ -3,6 +3,7 @@
import math
import sys
import unittest
from contextlib import contextmanager
from io import StringIO
from unittest.mock import MagicMock
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@contextmanager
def swapped_with_identity(obj, func):
old_func = getattr(obj, func)
setattr(obj, func, lambda x: x)
yield
setattr(obj, func, old_func)
class TestLora(unittest.TestCase):
def setUp(self):
self.capturedOutput = StringIO()
@ -374,6 +383,7 @@ class TestScheduleConfig(unittest.TestCase):
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
]
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate(
model=mock_model,
dataset=mock_dataset,
@ -412,6 +422,7 @@ class TestScheduleConfig(unittest.TestCase):
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
]
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate(
model=mock_model,
dataset=mock_dataset,