diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 107be092..6ba81628 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -3,6 +3,7 @@ import math import sys import unittest +from contextlib import contextmanager from io import StringIO from unittest.mock import MagicMock @@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate from mlx_lm.tuner.utils import build_schedule +@contextmanager +def swapped_with_identity(obj, func): + old_func = getattr(obj, func) + setattr(obj, func, lambda x: x) + yield + setattr(obj, func, old_func) + + class TestLora(unittest.TestCase): def setUp(self): self.capturedOutput = StringIO() @@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.4), MagicMock(return_value=180)), (MagicMock(return_value=0.6), MagicMock(return_value=120)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=2, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=2, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, @@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.2), MagicMock(return_value=150)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=-1, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=-1, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset,