From 949f63f30972d105e7dc0ba49b9b3230daa230ac Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 21 Mar 2024 02:41:03 +1100
Subject: [PATCH] chore(mlx-lm): fix print_trainable_parameters for quant
models (#581)

* chore(mlx-lm): fix print_trainable_parameters for quant models

* chore: clean up

* refactor: use layer type to check quant bits

* chore: address comment
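
Context: MLX packs quantized weights into uint32 elements, so for a
QuantizedLinear layer weight.size is 32 // bits times smaller than the number
of parameters it represents. Summing raw array sizes therefore under-counted
the total and inflated the reported trainable percentage for quantized models.
A minimal sketch of the arithmetic (illustrative only; the layer sizes and the
from_linear call below are an example, not part of this patch):

    import mlx.nn as nn

    linear = nn.Linear(4096, 4096, bias=False)   # 4096 * 4096 = 16,777,216 weights
    qlinear = nn.QuantizedLinear.from_linear(linear, group_size=64, bits=4)

    # 4-bit values are packed 8 per uint32, so the stored array is 8x smaller.
    print(qlinear.weight.size)                         # 2,097,152 packed elements
    print(qlinear.weight.size * (32 // qlinear.bits))  # 16,777,216 -- original count

The fix walks the leaf modules and expands QuantizedLinear layers by
32 // bits when computing the denominator.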
---
llms/mlx_lm/lora.py | 11 +++++-
llms/tests/test_lora.py | 76 +++++++++++++++++++++++++++++++++++++++--
2 files changed, 84 insertions(+), 3 deletions(-)
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index a31e973f..b89d8f0e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -7,6 +7,7 @@ import re
import types
from pathlib import Path
+import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
@@ -143,7 +144,15 @@ def build_parser():
 def print_trainable_parameters(model):
-    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    def nparams(m):
+        if isinstance(m, nn.QuantizedLinear):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
     trainable_p = (
         sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
     )
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
index ef3ea78e..f7666a42 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_lora.py
@@ -1,13 +1,23 @@
# Copyright © 2024 Apple Inc.
+import sys
import unittest
+from io import StringIO
+from unittest.mock import MagicMock
-import mlx.core as mx
+import mlx.nn as nn
from mlx.utils import tree_flatten
-from mlx_lm import tuner, utils
+from mlx_lm import lora, tuner
+from mlx_lm.tuner.lora import LoRALinear
class TestLora(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
     def test_to_lora(self):
         from mlx_lm.models import llama
@@ -47,6 +57,68 @@ class TestLora(unittest.TestCase):
params["keys"] = ["self_attn.k_proj"]
check_config(params)
+ def test_quantized_print_trainable_parameters(self):
+ model = MagicMock()
+ quantized_linear = MagicMock(spec=nn.QuantizedLinear)
+ quantized_linear.weight = MagicMock(size=1e6)
+ quantized_linear.bits = 8
+ lora_linear = MagicMock(spec=LoRALinear)
+ lora_linear.weight = MagicMock(size=2e6)
+ lora_linear.parameters.return_value = [lora_linear.weight]
+
+ linear = MagicMock(spec=nn.Linear)
+ linear.weight = MagicMock(size=3e6)
+ linear.parameters.return_value = [linear.weight]
+
+ model.leaf_modules.return_value = {
+ "quantized_linear": quantized_linear,
+ "lora_linear": lora_linear,
+ "linear": linear,
+ }
+
+ model.trainable_parameters.return_value = {
+ "layer1.weight": MagicMock(size=1e6),
+ "layer3.weight": MagicMock(size=2e6),
+ }
+ expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
+ lora.print_trainable_parameters(model)
+ self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
+ self.capturedOutput.truncate(0)
+ self.capturedOutput.seek(0)
+
+ quantized_linear.weight = MagicMock(size=1e6)
+ quantized_linear.bits = 4
+ expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
+ lora.print_trainable_parameters(model)
+ self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
+ self.capturedOutput.truncate(0)
+ self.capturedOutput.seek(0)
+
+    def test_print_trainable_parameters(self):
+        model = MagicMock()
+        linear1 = MagicMock(spec=nn.Linear)
+        linear1.weight = MagicMock(size=1e6)
+        linear1.parameters.return_value = [linear1.weight]
+        linear2 = MagicMock(spec=nn.Linear)
+        linear2.weight = MagicMock(size=2e6)
+        linear2.parameters.return_value = [linear2.weight]
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=3e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+        model.leaf_modules.return_value = {
+            "linear1": linear1,
+            "linear2": linear2,
+            "lora_linear": lora_linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
+
 
 if __name__ == "__main__":
     unittest.main()