mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
more concise bpw
This commit is contained in:
parent
3b5cd401d8
commit
46109e4141
@ -249,12 +249,13 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
def nparams(m):
|
||||
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
||||
return m.weight.size * (32 // m.bits)
|
||||
return sum(v.size for _, v in tree_flatten(m.parameters()))
|
||||
def nparams(module):
|
||||
if isinstance(module, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
||||
return module.weight.size * 32 // module.bits
|
||||
return sum(v.size for _, v in tree_flatten(module.parameters()))
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
|
@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten, tree_reduce
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
@ -23,7 +23,7 @@ from .models import cache
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
from .tuner.utils import load_adapters
|
||||
from .tuner.utils import load_adapters, nparams
|
||||
|
||||
# Constants
|
||||
MODEL_REMAPPING = {
|
||||
@ -100,6 +100,17 @@ def _get_classes(config: dict):
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
def compute_bits_per_weight(model):
|
||||
model_bytes = tree_reduce(
|
||||
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
||||
)
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
model_params = sum(nparams(m) for _, m in leaf_modules)
|
||||
return model_bytes * 8 / model_params
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Ensures the model is available locally. If the path does not exist locally,
|
||||
@ -404,30 +415,6 @@ def load_config(model_path: Path) -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def compute_bits_per_weight(config, weights):
|
||||
weight_count = 0
|
||||
bit_count = 0
|
||||
for name, param in weights.items():
|
||||
bit_count += param.size * param.dtype.size * 8
|
||||
if ".scales" in name or ".biases" in name:
|
||||
continue
|
||||
if param.dtype == mx.uint32:
|
||||
base_name = name.replace(".weight", "")
|
||||
quant_config = config["quantization"]
|
||||
quant_layer_config = quant_config.get(base_name)
|
||||
bits = (
|
||||
quant_config["bits"]
|
||||
if not quant_layer_config
|
||||
else quant_layer_config["bits"]
|
||||
)
|
||||
weight_count += param.size * 32 / bits
|
||||
else:
|
||||
weight_count += param.size
|
||||
|
||||
bits_per_weight = bit_count / weight_count
|
||||
return bits_per_weight
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: Path,
|
||||
lazy: bool = False,
|
||||
@ -742,7 +729,7 @@ def quantize_model(
|
||||
quantized_config["quantization_config"] = quantized_config["quantization"]
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
bpw = compute_bits_per_weight(quantized_config, quantized_weights)
|
||||
bpw = compute_bits_per_weight(model)
|
||||
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
Loading…
Reference in New Issue
Block a user