Mixed Quantizations (#1132)

* saving/loading mixed quantizations
* comment
* add bits per weight
* more concise bpw
* count bias too
Alex Barron authored 2024-12-08 14:21:50 -08:00, committed by GitHub
parent cd8cf28c39
commit 2211b27388
2 changed files with 61 additions and 12 deletions
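
For context, a mixed quantization is a model whose layers are quantized at different bit widths. Below is a minimal sketch of producing one, assuming an MLX version in which the class_predicate passed to nn.quantize may return a dict of per-layer quantization kwargs rather than a plain bool; the toy model and the 4-vs-8-bit policy are illustrative, not the commit's code.

import mlx.nn as nn

# Toy model; in practice this would be a full language model.
model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))

def mixed_predicate(path, module):
    # Only modules that know how to quantize themselves qualify.
    if not hasattr(module, "to_quantized"):
        return False
    # Illustrative policy: wider input dimensions tolerate fewer bits.
    bits = 4 if module.weight.shape[1] >= 1024 else 8
    return {"group_size": 64, "bits": bits}

nn.quantize(model, class_predicate=mixed_predicate)
# The two Linear layers now carry different `.bits`, which is why the
# reworked nparams in the diff below checks hasattr(module, "bits")
# instead of a fixed pair of quantized classes.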


@@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model
 
 
-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
+def nparams(module):
+    if hasattr(module, "bits"):
+        n = 0 if not hasattr(module, "bias") else module.bias.size
+        return n + module.weight.size * 32 // module.bits
+    return sum(v.size for _, v in tree_flatten(module.parameters()))
 
+
+def print_trainable_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
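
A quick sanity check of the new counting logic (a sketch; the 4-bit 1024x1024 layer is an arbitrary choice, and nparams is copied from the hunk above). A QuantizedLinear packs 32 // bits values into each uint32 entry of its weight, so multiplying the packed size by 32 // bits recovers the logical weight count; per the "count bias too" commit message, the bias is added on top.

import mlx.nn as nn
from mlx.utils import tree_flatten

def nparams(module):
    # Same logic as the added lines in the hunk above.
    if hasattr(module, "bits"):
        n = 0 if not hasattr(module, "bias") else module.bias.size
        return n + module.weight.size * 32 // module.bits
    return sum(v.size for _, v in tree_flatten(module.parameters()))

layer = nn.QuantizedLinear(1024, 1024, bias=True, group_size=64, bits=4)

# Packed storage: 32 // 4 = 8 weights per uint32 entry.
assert layer.weight.size == 1024 * 1024 // 8  # 131072
# Logical count: packed size expanded back, plus the bias.
assert nparams(layer) == 1024 * 1024 + 1024  # 1049600

Note also that the old expression m.weight.size * (32 // m.bits) truncates for bit widths that do not divide 32 (for example 3 or 6), while the reordered module.weight.size * 32 // module.bits multiplies before dividing and stays accurate for them.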