commit 3b5cd401d8 (parent 5828703a5a)
repo:   https://github.com/ml-explore/mlx-examples.git

    add bits per weight
@@ -404,6 +404,30 @@ def load_config(model_path: Path) -> dict:
     return config
 
 
+def compute_bits_per_weight(config, weights):
+    weight_count = 0
+    bit_count = 0
+    for name, param in weights.items():
+        bit_count += param.size * param.dtype.size * 8
+        if ".scales" in name or ".biases" in name:
+            continue
+        if param.dtype == mx.uint32:
+            base_name = name.replace(".weight", "")
+            quant_config = config["quantization"]
+            quant_layer_config = quant_config.get(base_name)
+            bits = (
+                quant_config["bits"]
+                if not quant_layer_config
+                else quant_layer_config["bits"]
+            )
+            weight_count += param.size * 32 / bits
+        else:
+            weight_count += param.size
+
+    bits_per_weight = bit_count / weight_count
+    return bits_per_weight
+
+
 def load_model(
     model_path: Path,
     lazy: bool = False,
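To see what the new helper reports, here is a small arithmetic sketch for one hypothetical layer, assuming 4-bit affine quantization with group size 64 and float16 scales and biases; the layer size and these settings are made up for illustration, not taken from the commit.

# Hypothetical 4096 x 4096 weight quantized to 4 bits with group_size 64.
n = 4096 * 4096               # logical number of weights in the layer
groups = n // 64              # one scale and one bias per group of 64 weights

# Mirror compute_bits_per_weight's bookkeeping for this one layer:
packed_bits = n * 4           # the packed uint32 tensor holds n * bits bits in total
scale_bits = groups * 16      # float16 scales ('.scales' entries count toward bit_count only)
bias_bits = groups * 16       # float16 biases ('.biases' entries count toward bit_count only)

bit_count = packed_bits + scale_bits + bias_bits
weight_count = n              # param.size * 32 / bits recovers the logical weight count
print(bit_count / weight_count)   # 4.5 bits per weight for this layer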
@@ -705,8 +729,7 @@ def quantize_model(
     # Add any custom quantization parameters to the config as we go
     def _class_predicate(p, m):
         bool_or_params = quant_predicate(p, m, config)
-        if isinstance(bool_or_params, dict):
-            quantized_config["quantization"][p] = bool_or_params
+        quantized_config["quantization"][p] = bool_or_params
         return bool_or_params
 
     nn.quantize(
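For context, the predicate hook used here can return either a bool or a dict of per-layer parameters. The sketch below is hypothetical (the layer-name test and the 8-bit, group-size-32 choice are invented), but it follows the return convention above: whatever the predicate returns is recorded under quantized_config["quantization"][p], which is where compute_bits_per_weight later looks up per-layer bits.

def my_quant_predicate(path, module, config):
    # Hypothetical policy: keep embedding weights at 8 bits, quantize everything
    # else with the run-wide defaults. A dict return is stored verbatim by
    # _class_predicate above and ends up in the saved quantization config.
    if "embed" in path:
        return {"bits": 8, "group_size": 32}
    return True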
@@ -719,6 +742,9 @@ def quantize_model(
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
 
+    bpw = compute_bits_per_weight(quantized_config, quantized_weights)
+    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
+
     return quantized_weights, quantized_config
 
 
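The helper can also be run by hand on an already-converted model. The snippet below is a rough sketch rather than part of the commit, and "mlx_model" is a placeholder for any local directory produced by the quantization step (a config.json plus safetensors shards).

import glob
import json
from pathlib import Path

import mlx.core as mx

model_path = Path("mlx_model")  # placeholder: output directory from convert/quantize

with open(model_path / "config.json") as f:
    config = json.load(f)

# Gather all weight shards into one name -> array dict.
weights = {}
for shard in glob.glob(str(model_path / "*.safetensors")):
    weights.update(mx.load(shard))

print(f"{compute_bits_per_weight(config, weights):.3f} bits per weight")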