diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 6ab3edb7..640c1008 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -404,6 +404,36 @@ def load_config(model_path: Path) -> dict:
     return config
 
 
+def compute_bits_per_weight(config, weights):
+    weight_count = 0
+    bit_count = 0
+    for name, param in weights.items():
+        # Every parameter, including quantization scales and biases,
+        # counts toward the total storage size.
+        bit_count += param.size * param.dtype.size * 8
+        # Scales and biases are quantization overhead, not weights.
+        if ".scales" in name or ".biases" in name:
+            continue
+        if param.dtype == mx.uint32:
+            # Quantized weights are packed into uint32s, so each element
+            # stores 32 / bits weights. Use the per-layer bit width when
+            # one was recorded, otherwise fall back to the global setting.
+            base_name = name.replace(".weight", "")
+            quant_config = config["quantization"]
+            quant_layer_config = quant_config.get(base_name)
+            bits = (
+                quant_layer_config["bits"]
+                if isinstance(quant_layer_config, dict)
+                else quant_config["bits"]
+            )
+            weight_count += param.size * 32 / bits
+        else:
+            weight_count += param.size
+
+    bits_per_weight = bit_count / weight_count
+    return bits_per_weight
+
+
 def load_model(
     model_path: Path,
     lazy: bool = False,
@@ -705,8 +735,7 @@ def quantize_model(
     # Add any custom quantization parameters to the config as we go
     def _class_predicate(p, m):
         bool_or_params = quant_predicate(p, m, config)
-        if isinstance(bool_or_params, dict):
-            quantized_config["quantization"][p] = bool_or_params
+        quantized_config["quantization"][p] = bool_or_params
         return bool_or_params
 
     nn.quantize(
@@ -719,6 +748,9 @@ def quantize_model(
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
 
+    bpw = compute_bits_per_weight(quantized_config, quantized_weights)
+    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
+
     return quantized_weights, quantized_config
 
 
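
A quick sanity check of the bits-per-weight arithmetic (a standalone sketch, not part of the patch; the 4-bit, group-size-64 layer with fp16 scales and biases below is a hypothetical example). Assuming MLX packs 32 / bits quantized values into each uint32 element and stores one scale and one bias per group, the expected result is bits + 2 * 16 / group_size = 4.5:

import mlx.core as mx

bits, group_size = 4, 64
out_dim, in_dim = 4096, 4096

# Packed quantized weights: 32 // bits values per uint32 element.
weight = mx.zeros((out_dim, in_dim * bits // 32), dtype=mx.uint32)
# One fp16 scale and one fp16 bias per group of `group_size` weights.
scales = mx.zeros((out_dim, in_dim // group_size), dtype=mx.float16)
biases = mx.zeros((out_dim, in_dim // group_size), dtype=mx.float16)

# Mirror compute_bits_per_weight: total stored bits over true weight count.
bit_count = sum(p.size * p.dtype.size * 8 for p in (weight, scales, biases))
weight_count = weight.size * 32 / bits
print(bit_count / weight_count)  # 4.5 == bits + 2 * 16 / group_size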