add bits per weight

This commit is contained in:
Alex Barron 2024-12-04 08:31:40 -08:00
parent 5828703a5a
commit 3b5cd401d8

View File

@ -404,6 +404,30 @@ def load_config(model_path: Path) -> dict:
return config
def compute_bits_per_weight(config, weights):
weight_count = 0
bit_count = 0
for name, param in weights.items():
bit_count += param.size * param.dtype.size * 8
if ".scales" in name or ".biases" in name:
continue
if param.dtype == mx.uint32:
base_name = name.replace(".weight", "")
quant_config = config["quantization"]
quant_layer_config = quant_config.get(base_name)
bits = (
quant_config["bits"]
if not quant_layer_config
else quant_layer_config["bits"]
)
weight_count += param.size * 32 / bits
else:
weight_count += param.size
bits_per_weight = bit_count / weight_count
return bits_per_weight
def load_model(
model_path: Path,
lazy: bool = False,
@ -705,8 +729,7 @@ def quantize_model(
# Add any custom quantization parameters to the config as we go
def _class_predicate(p, m):
bool_or_params = quant_predicate(p, m, config)
if isinstance(bool_or_params, dict):
quantized_config["quantization"][p] = bool_or_params
quantized_config["quantization"][p] = bool_or_params
return bool_or_params
nn.quantize(
@ -719,6 +742,9 @@ def quantize_model(
quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters()))
bpw = compute_bits_per_weight(quantized_config, quantized_weights)
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
return quantized_weights, quantized_config