chore(mlx-lm): fix print_trainable_parameters for quant models (#581)

* chore(mlx-lm): fix print_trainable_parameters for quant models

* chore: clean up

* refactor: use layer type to check quant bits

* chore: address comment
Anchen 2024-03-21 02:41:03 +11:00 committed by GitHub
parent 373dd6f2a2
commit 949f63f309
2 changed files with 84 additions and 3 deletions


@@ -7,6 +7,7 @@ import re
 import types
 from pathlib import Path
 
+import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 import yaml
@@ -143,7 +144,15 @@ def build_parser():
 
 
 def print_trainable_parameters(model):
-    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    def nparams(m):
+        if isinstance(m, nn.QuantizedLinear):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
     trainable_p = (
         sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
     )
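For context: an nn.QuantizedLinear stores its weights packed into 32-bit elements, so weight.size reports the packed storage size rather than the original parameter count; multiplying by 32 // bits undoes the packing, which is what the new nparams helper does. A minimal sketch of the arithmetic, with hypothetical sizes not taken from the commit:

# Illustrative arithmetic only; the sizes below are hypothetical.
packed_size = 1_000_000           # what weight.size reports for an nn.QuantizedLinear
bits = 8                          # quantization width of the layer
weights_per_element = 32 // bits  # 4 low-bit weights packed into each 32-bit element
original_params = packed_size * weights_per_element  # 4,000,000 unpacked parameters
print(f"{original_params / 10**6:.3f}M parameters")  # -> 4.000M

An unquantized nn.Linear needs no correction, so the helper falls back to summing v.size over the module's parameters.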


@@ -1,13 +1,23 @@
 # Copyright © 2024 Apple Inc.
 
+import sys
 import unittest
+from io import StringIO
+from unittest.mock import MagicMock
 
-import mlx.core as mx
+import mlx.nn as nn
 from mlx.utils import tree_flatten
-from mlx_lm import tuner, utils
+from mlx_lm import lora, tuner
+from mlx_lm.tuner.lora import LoRALinear
 
 
 class TestLora(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
+
     def test_to_lora(self):
         from mlx_lm.models import llama
@@ -47,6 +57,68 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)
 
+    def test_quantized_print_trainable_parameters(self):
+        model = MagicMock()
+        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 8
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=2e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        linear = MagicMock(spec=nn.Linear)
+        linear.weight = MagicMock(size=3e6)
+        linear.parameters.return_value = [linear.weight]
+
+        model.leaf_modules.return_value = {
+            "quantized_linear": quantized_linear,
+            "lora_linear": lora_linear,
+            "linear": linear,
+        }
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 4
+        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+    def test_print_trainable_parameters(self):
+        model = MagicMock()
+        linear1 = MagicMock(spec=nn.Linear)
+        linear1.weight = MagicMock(size=1e6)
+        linear1.parameters.return_value = [linear1.weight]
+        linear2 = MagicMock(spec=nn.Linear)
+        linear2.weight = MagicMock(size=2e6)
+        linear2.parameters.return_value = [linear2.weight]
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=3e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        model.leaf_modules.return_value = {
+            "linear1": linear1,
+            "linear2": linear2,
+            "lora_linear": lora_linear,
+        }
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
+
 
 if __name__ == "__main__":
     unittest.main()
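For reference, the expected strings in the two tests follow directly from the mocked sizes. A quick sanity check of the arithmetic (not part of the commit; pct is a hypothetical helper that only mirrors the printed format):

# Hypothetical helper reproducing the printed format for a sanity check.
def pct(trainable_m, total_m):
    return f"Trainable parameters: {trainable_m / total_m * 100:.3f}% ({trainable_m:.3f}M/{total_m:.3f}M)"

print(pct(3.0, 1.0 * (32 // 8) + 2.0 + 3.0))  # 8-bit quantized: 33.333% (3.000M/9.000M)
print(pct(3.0, 1.0 * (32 // 4) + 2.0 + 3.0))  # 4-bit quantized: 23.077% (3.000M/13.000M)
print(pct(3.0, 1.0 + 2.0 + 3.0))              # no quantization: 50.000% (3.000M/6.000M)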