diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 87f31f68..109d7eee 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -250,8 +250,9 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model
 
 
 def nparams(module):
-    if isinstance(module, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-        return module.weight.size * 32 // module.bits
+    if hasattr(module, "bits"):
+        n = 0 if not hasattr(module, "bias") else module.bias.size
+        return n + module.weight.size * 32 // module.bits
     return sum(v.size for _, v in tree_flatten(module.parameters()))
 