mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 23:49:43 +08:00
add bits per weight
This commit is contained in:
parent
5828703a5a
commit
3b5cd401d8
@ -404,6 +404,30 @@ def load_config(model_path: Path) -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def compute_bits_per_weight(config, weights):
|
||||
weight_count = 0
|
||||
bit_count = 0
|
||||
for name, param in weights.items():
|
||||
bit_count += param.size * param.dtype.size * 8
|
||||
if ".scales" in name or ".biases" in name:
|
||||
continue
|
||||
if param.dtype == mx.uint32:
|
||||
base_name = name.replace(".weight", "")
|
||||
quant_config = config["quantization"]
|
||||
quant_layer_config = quant_config.get(base_name)
|
||||
bits = (
|
||||
quant_config["bits"]
|
||||
if not quant_layer_config
|
||||
else quant_layer_config["bits"]
|
||||
)
|
||||
weight_count += param.size * 32 / bits
|
||||
else:
|
||||
weight_count += param.size
|
||||
|
||||
bits_per_weight = bit_count / weight_count
|
||||
return bits_per_weight
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: Path,
|
||||
lazy: bool = False,
|
||||
@ -705,8 +729,7 @@ def quantize_model(
|
||||
# Add any custom quantization parameters to the config as we go
|
||||
def _class_predicate(p, m):
|
||||
bool_or_params = quant_predicate(p, m, config)
|
||||
if isinstance(bool_or_params, dict):
|
||||
quantized_config["quantization"][p] = bool_or_params
|
||||
quantized_config["quantization"][p] = bool_or_params
|
||||
return bool_or_params
|
||||
|
||||
nn.quantize(
|
||||
@ -719,6 +742,9 @@ def quantize_model(
|
||||
quantized_config["quantization_config"] = quantized_config["quantization"]
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
bpw = compute_bits_per_weight(quantized_config, quantized_weights)
|
||||
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user