# Copyright © 2024 Apple Inc.

import sys
import unittest
from io import StringIO
from unittest.mock import MagicMock

import mlx.nn as nn

from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.utils import print_trainable_parameters


class TestTunerUtils(unittest.TestCase):
    """Tests for ``print_trainable_parameters``.

    Each test builds a ``MagicMock`` model whose ``leaf_modules()`` and
    ``trainable_parameters()`` return mocks with fixed ``size`` values, then
    compares the single line printed to stdout against an exact string.
    """

    def setUp(self):
        # Save whatever stream is currently installed. A test runner (e.g.
        # pytest's capture) may have replaced it, and it must be restored
        # rather than clobbered with the interpreter's original stream.
        self._saved_stdout = sys.stdout
        self.capturedOutput = StringIO()
        sys.stdout = self.capturedOutput

    def tearDown(self):
        # Restore the stream that was active before setUp, not
        # sys.__stdout__, so output capture keeps working for later tests.
        sys.stdout = self._saved_stdout

    def _reset_output(self):
        """Clear the captured stdout buffer so the next print starts fresh."""
        self.capturedOutput.truncate(0)
        self.capturedOutput.seek(0)

    def test_quantized_print_trainable_parameters(self):
        """Check totals when one leaf module is an ``nn.QuantizedLinear``.

        The expected totals assume a quantized weight of stored size S at
        ``bits`` b counts as S * 32 / b parameters. That is 4x at 8 bits
        (4M + 2M + 3M = 9M) and 8x at 4 bits (8M + 2M + 3M = 13M). The
        trainable count (1M + 2M = 3M) comes from ``trainable_parameters()``.
        """
        model = MagicMock()

        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
        quantized_linear.weight = MagicMock(size=1e6)
        quantized_linear.bits = 8

        lora_linear = MagicMock(spec=LoRALinear)
        lora_linear.weight = MagicMock(size=2e6)
        lora_linear.parameters.return_value = [lora_linear.weight]

        linear = MagicMock(spec=nn.Linear)
        linear.weight = MagicMock(size=3e6)
        linear.parameters.return_value = [linear.weight]

        model.leaf_modules.return_value = {
            "quantized_linear": quantized_linear,
            "lora_linear": lora_linear,
            "linear": linear,
        }
        model.trainable_parameters.return_value = {
            "layer1.weight": MagicMock(size=1e6),
            "layer3.weight": MagicMock(size=2e6),
        }

        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
        print_trainable_parameters(model)
        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
        self._reset_output()

        # Only the bit width changes; the stored weight size stays 1e6.
        quantized_linear.bits = 4
        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
        print_trainable_parameters(model)
        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
        self._reset_output()

    def test_print_trainable_parameters(self):
        """Check totals for unquantized leaf modules.

        The total is the sum of each leaf's ``parameters()`` sizes
        (1M + 2M + 3M = 6M). The trainable count (1M + 2M = 3M) comes from
        ``trainable_parameters()``.
        """
        model = MagicMock()

        linear1 = MagicMock(spec=nn.Linear)
        linear1.weight = MagicMock(size=1e6)
        linear1.parameters.return_value = [linear1.weight]

        linear2 = MagicMock(spec=nn.Linear)
        linear2.weight = MagicMock(size=2e6)
        linear2.parameters.return_value = [linear2.weight]

        lora_linear = MagicMock(spec=LoRALinear)
        lora_linear.weight = MagicMock(size=3e6)
        lora_linear.parameters.return_value = [lora_linear.weight]

        model.leaf_modules.return_value = {
            "linear1": linear1,
            "linear2": linear2,
            "lora_linear": lora_linear,
        }
        model.trainable_parameters.return_value = {
            "layer1.weight": MagicMock(size=1e6),
            "layer3.weight": MagicMock(size=2e6),
        }

        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
        print_trainable_parameters(model)
        self.assertEqual(self.capturedOutput.getvalue(), expected_output)


if __name__ == "__main__":
    unittest.main()