mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
more concise bpw
This commit is contained in:
parent
3b5cd401d8
commit
46109e4141
@ -249,12 +249,13 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def print_trainable_parameters(model):
|
def nparams(module):
|
||||||
def nparams(m):
|
if isinstance(module, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
||||||
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
return module.weight.size * 32 // module.bits
|
||||||
return m.weight.size * (32 // m.bits)
|
return sum(v.size for _, v in tree_flatten(module.parameters()))
|
||||||
return sum(v.size for _, v in tree_flatten(m.parameters()))
|
|
||||||
|
|
||||||
|
|
||||||
|
def print_trainable_parameters(model):
|
||||||
leaf_modules = tree_flatten(
|
leaf_modules = tree_flatten(
|
||||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from mlx.utils import tree_flatten, tree_reduce
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
@ -23,7 +23,7 @@ from .models import cache
|
|||||||
from .sample_utils import make_logits_processors, make_sampler
|
from .sample_utils import make_logits_processors, make_sampler
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
from .tuner.utils import load_adapters
|
from .tuner.utils import load_adapters, nparams
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MODEL_REMAPPING = {
|
MODEL_REMAPPING = {
|
||||||
@ -100,6 +100,17 @@ def _get_classes(config: dict):
|
|||||||
return arch.Model, arch.ModelArgs
|
return arch.Model, arch.ModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
def compute_bits_per_weight(model):
|
||||||
|
model_bytes = tree_reduce(
|
||||||
|
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
||||||
|
)
|
||||||
|
leaf_modules = tree_flatten(
|
||||||
|
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||||
|
)
|
||||||
|
model_params = sum(nparams(m) for _, m in leaf_modules)
|
||||||
|
return model_bytes * 8 / model_params
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Ensures the model is available locally. If the path does not exist locally,
|
Ensures the model is available locally. If the path does not exist locally,
|
||||||
@ -404,30 +415,6 @@ def load_config(model_path: Path) -> dict:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def compute_bits_per_weight(config, weights):
|
|
||||||
weight_count = 0
|
|
||||||
bit_count = 0
|
|
||||||
for name, param in weights.items():
|
|
||||||
bit_count += param.size * param.dtype.size * 8
|
|
||||||
if ".scales" in name or ".biases" in name:
|
|
||||||
continue
|
|
||||||
if param.dtype == mx.uint32:
|
|
||||||
base_name = name.replace(".weight", "")
|
|
||||||
quant_config = config["quantization"]
|
|
||||||
quant_layer_config = quant_config.get(base_name)
|
|
||||||
bits = (
|
|
||||||
quant_config["bits"]
|
|
||||||
if not quant_layer_config
|
|
||||||
else quant_layer_config["bits"]
|
|
||||||
)
|
|
||||||
weight_count += param.size * 32 / bits
|
|
||||||
else:
|
|
||||||
weight_count += param.size
|
|
||||||
|
|
||||||
bits_per_weight = bit_count / weight_count
|
|
||||||
return bits_per_weight
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
@ -742,7 +729,7 @@ def quantize_model(
|
|||||||
quantized_config["quantization_config"] = quantized_config["quantization"]
|
quantized_config["quantization_config"] = quantized_config["quantization"]
|
||||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||||
|
|
||||||
bpw = compute_bits_per_weight(quantized_config, quantized_weights)
|
bpw = compute_bits_per_weight(model)
|
||||||
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
|
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
|
||||||
|
|
||||||
return quantized_weights, quantized_config
|
return quantized_weights, quantized_config
|
||||||
|
Loading…
Reference in New Issue
Block a user