mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Mixed Quantizations (#1132)
* saving/loading mixed quantizations * comment * add bits per weight * more concise bpw * count bias too
This commit is contained in:
@@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
def nparams(m):
|
||||
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
||||
return m.weight.size * (32 // m.bits)
|
||||
return sum(v.size for _, v in tree_flatten(m.parameters()))
|
||||
def nparams(module):
|
||||
if hasattr(module, "bits"):
|
||||
n = 0 if not hasattr(module, "bias") else module.bias.size
|
||||
return n + module.weight.size * 32 // module.bits
|
||||
return sum(v.size for _, v in tree_flatten(module.parameters()))
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
|
Reference in New Issue
Block a user