From c5e09a17253d1ebcf4a2a5482aa363b05ff2af31 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 31 Oct 2024 16:28:55 -0700 Subject: [PATCH] Fix the test --- llms/tests/test_finetune.py | 51 ++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 107be092..6ba81628 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -3,6 +3,7 @@ import math import sys import unittest +from contextlib import contextmanager from io import StringIO from unittest.mock import MagicMock @@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate from mlx_lm.tuner.utils import build_schedule +@contextmanager +def swapped_with_identity(obj, func): + old_func = getattr(obj, func) + setattr(obj, func, lambda x: x) + yield + setattr(obj, func, old_func) + + class TestLora(unittest.TestCase): def setUp(self): self.capturedOutput = StringIO() @@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.4), MagicMock(return_value=180)), (MagicMock(return_value=0.6), MagicMock(return_value=120)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=2, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=2, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, @@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.2), MagicMock(return_value=150)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=-1, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=-1, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset,