Mirror of https://github.com/ml-explore/mlx-examples.git
chore(mlx-lm): fix print_trainable_parameters for quant models (#581)
* chore(mlx-lm): fix print_trainable_parameters for quant models
* chore: clean up
* refactor: use layer type to check quant bits
* chore: address comment
parent 373dd6f2a2
commit 949f63f309
@@ -7,6 +7,7 @@ import re
 import types
 from pathlib import Path

+import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 import yaml
@@ -143,7 +144,15 @@ def build_parser():


 def print_trainable_parameters(model):
-    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    def nparams(m):
+        if isinstance(m, nn.QuantizedLinear):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
     trainable_p = (
         sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
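The core of the fix is how `nparams` counts a quantized layer: a `QuantizedLinear` stores its weights packed into 32-bit words, so each stored element represents `32 // bits` original weights, which is why the count is scaled by that factor (and why the first hunk adds the `mlx.nn` import, needed for the `isinstance` check). A minimal standalone sketch of that arithmetic; the helper name `quantized_nparams` is chosen here for illustration and is not part of mlx-lm:

# Illustrative sketch only; `quantized_nparams` is not an mlx-lm API.
# A QuantizedLinear packs its weights into 32-bit words, so a stored
# array of size N represents N * (32 // bits) original parameters.
def quantized_nparams(packed_size: int, bits: int) -> int:
    return packed_size * (32 // bits)

# A packed weight with 1M stored elements:
assert quantized_nparams(1_000_000, 8) == 4_000_000  # 8-bit: 4 weights per word
assert quantized_nparams(1_000_000, 4) == 8_000_000  # 4-bit: 8 weights per word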
@@ -1,13 +1,23 @@
 # Copyright © 2024 Apple Inc.

+import sys
 import unittest
+from io import StringIO
+from unittest.mock import MagicMock

 import mlx.core as mx
+import mlx.nn as nn
 from mlx.utils import tree_flatten
-from mlx_lm import tuner, utils
+from mlx_lm import lora, tuner
+from mlx_lm.tuner.lora import LoRALinear


 class TestLora(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
+
     def test_to_lora(self):
         from mlx_lm.models import llama
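The new setUp/tearDown pair redirects `sys.stdout` into a `StringIO` buffer so each test can assert on exactly what `print_trainable_parameters` printed. A minimal sketch of that capture pattern in isolation (variable names here are illustrative):

import sys
from io import StringIO

captured = StringIO()
sys.stdout = captured          # print() now writes into the buffer
print("hello")
sys.stdout = sys.__stdout__    # restore the real stdout
assert captured.getvalue() == "hello\n"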
@@ -47,6 +57,68 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)

+    def test_quantized_print_trainable_parameters(self):
+        model = MagicMock()
+        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 8
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=2e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+
+        linear = MagicMock(spec=nn.Linear)
+        linear.weight = MagicMock(size=3e6)
+        linear.parameters.return_value = [linear.weight]
+
+        model.leaf_modules.return_value = {
+            "quantized_linear": quantized_linear,
+            "lora_linear": lora_linear,
+            "linear": linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+        quantized_linear.weight = MagicMock(size=1e6)
+        quantized_linear.bits = 4
+        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
+        self.capturedOutput.truncate(0)
+        self.capturedOutput.seek(0)
+
+    def test_print_trainable_parameters(self):
+        model = MagicMock()
+        linear1 = MagicMock(spec=nn.Linear)
+        linear1.weight = MagicMock(size=1e6)
+        linear1.parameters.return_value = [linear1.weight]
+        linear2 = MagicMock(spec=nn.Linear)
+        linear2.weight = MagicMock(size=2e6)
+        linear2.parameters.return_value = [linear2.weight]
+        lora_linear = MagicMock(spec=LoRALinear)
+        lora_linear.weight = MagicMock(size=3e6)
+        lora_linear.parameters.return_value = [lora_linear.weight]
+        model.leaf_modules.return_value = {
+            "linear1": linear1,
+            "linear2": linear2,
+            "lora_linear": lora_linear,
+        }
+
+        model.trainable_parameters.return_value = {
+            "layer1.weight": MagicMock(size=1e6),
+            "layer3.weight": MagicMock(size=2e6),
+        }
+        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
+        lora.print_trainable_parameters(model)
+        self.assertEqual(self.capturedOutput.getvalue(), expected_output)
+

 if __name__ == "__main__":
     unittest.main()
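The expected strings in these tests follow directly from the packing math above. In the quantized test, the mocked totals are 1M packed elements * (32 // 8) = 4M parameters, plus 2M (LoRA) and 3M (plain linear), for 9M at 8 bits; at 4 bits the packed layer expands to 1M * (32 // 4) = 8M, for 13M total. Trainable parameters stay at 1M + 2M = 3M, giving 3/9 = 33.333% and 3/13 = 23.077%. A quick re-derivation of the formatted strings; the `expected` helper is introduced here only for illustration, its layout mirroring what the tests assert:

def expected(trainable_m: float, total_m: float) -> str:
    # Same layout as the strings asserted in the tests above.
    return (
        f"Trainable parameters: {trainable_m / total_m * 100:.3f}% "
        f"({trainable_m:.3f}M/{total_m:.3f}M)\n"
    )

assert expected(3.0, 9.0) == "Trainable parameters: 33.333% (3.000M/9.000M)\n"    # 8-bit case
assert expected(3.0, 13.0) == "Trainable parameters: 23.077% (3.000M/13.000M)\n"  # 4-bit case
assert expected(3.0, 6.0) == "Trainable parameters: 50.000% (3.000M/6.000M)\n"    # unquantized case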