diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index a31e973f..b89d8f0e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -7,6 +7,7 @@ import re
 import types
 from pathlib import Path
 
+import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 import yaml
@@ -143,7 +144,15 @@ def build_parser():
 
 
 def print_trainable_parameters(model):
-    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    def nparams(m):
+        if isinstance(m, nn.QuantizedLinear):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
     trainable_p = (
         sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
     )
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
index ef3ea78e..f7666a42 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_lora.py
@@ -1,13 +1,23 @@
 # Copyright © 2024 Apple Inc.
 
+import sys
 import unittest
+from io import StringIO
+from unittest.mock import MagicMock
 
-import mlx.core as mx
+import mlx.nn as nn
 from mlx.utils import tree_flatten
-from mlx_lm import tuner, utils
+from mlx_lm import lora, tuner
+from mlx_lm.tuner.lora import LoRALinear
 
 
 class TestLora(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
 
     def test_to_lora(self):
         from mlx_lm.models import llama
@@ -47,6 +57,68 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)
 
+    def test_quantized_print_trainable_parameters(self):
+        model = MagicMock()
+        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 8
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=2e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        linear = MagicMock(spec=nn.Linear)
+        linear.weight = MagicMock(size=3e6)
+        linear.parameters.return_value = [linear.weight]
+
+        model.leaf_modules.return_value = {
+            "quantized_linear": quantized_linear,
+            "lora_linear": lora_linear,
+            "linear": linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 4
+        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+    def test_print_trainable_parameters(self):
+        model = MagicMock()
+        linear1 = MagicMock(spec=nn.Linear)
+        linear1.weight = MagicMock(size=1e6)
+        linear1.parameters.return_value = [linear1.weight]
+        linear2 = MagicMock(spec=nn.Linear)
+        linear2.weight = MagicMock(size=2e6)
+        linear2.parameters.return_value = [linear2.weight]
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=3e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+        model.leaf_modules.return_value = {
+            "linear1": linear1,
+            "linear2": linear2,
+            "lora_linear": lora_linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
+
 
 if __name__ == "__main__":
     unittest.main()
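For reference, a minimal standalone sketch (not part of the patch, helper name is hypothetical) of the counting logic the new nparams branch applies and how it produces the expected percentages in the tests above: a quantized weight packs 32 // bits original parameters into each stored element, so multiplying recovers the pre-quantization count.

# Sketch only: reproduces the arithmetic behind the test expectations.
def original_param_count(packed_size, bits):
    # Each packed element holds (32 // bits) original parameters,
    # mirroring m.weight.size * (32 // m.bits) in print_trainable_parameters.
    return packed_size * (32 // bits)

# 8-bit case: 1e6 * 4 = 4M quantized + 2M lora + 3M linear = 9M total, 3M trainable
print(3e6 / (original_param_count(1e6, 8) + 2e6 + 3e6))  # 0.3333... -> 33.333%

# 4-bit case: 1e6 * 8 = 8M quantized + 2M + 3M = 13M total, 3M trainable
print(3e6 / (original_param_count(1e6, 4) + 2e6 + 3e6))  # 0.2307... -> 23.077%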