more concise bpw

This commit is contained in:
Alex Barron 2024-12-08 12:44:29 -08:00
parent 3b5cd401d8
commit 46109e4141
2 changed files with 20 additions and 32 deletions

View File

@ -249,12 +249,13 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
return model return model
def print_trainable_parameters(model): def nparams(module):
def nparams(m): if isinstance(module, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): return module.weight.size * 32 // module.bits
return m.weight.size * (32 // m.bits) return sum(v.size for _, v in tree_flatten(module.parameters()))
return sum(v.size for _, v in tree_flatten(m.parameters()))
def print_trainable_parameters(model):
leaf_modules = tree_flatten( leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
) )

View File

@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
@ -23,7 +23,7 @@ from .models import cache
from .sample_utils import make_logits_processors, make_sampler from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters from .tuner.utils import load_adapters, nparams
# Constants # Constants
MODEL_REMAPPING = { MODEL_REMAPPING = {
@ -100,6 +100,17 @@ def _get_classes(config: dict):
return arch.Model, arch.ModelArgs return arch.Model, arch.ModelArgs
def compute_bits_per_weight(model):
model_bytes = tree_reduce(
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
model_params = sum(nparams(m) for _, m in leaf_modules)
return model_bytes * 8 / model_params
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
""" """
Ensures the model is available locally. If the path does not exist locally, Ensures the model is available locally. If the path does not exist locally,
@ -404,30 +415,6 @@ def load_config(model_path: Path) -> dict:
return config return config
def compute_bits_per_weight(config, weights):
weight_count = 0
bit_count = 0
for name, param in weights.items():
bit_count += param.size * param.dtype.size * 8
if ".scales" in name or ".biases" in name:
continue
if param.dtype == mx.uint32:
base_name = name.replace(".weight", "")
quant_config = config["quantization"]
quant_layer_config = quant_config.get(base_name)
bits = (
quant_config["bits"]
if not quant_layer_config
else quant_layer_config["bits"]
)
weight_count += param.size * 32 / bits
else:
weight_count += param.size
bits_per_weight = bit_count / weight_count
return bits_per_weight
def load_model( def load_model(
model_path: Path, model_path: Path,
lazy: bool = False, lazy: bool = False,
@ -742,7 +729,7 @@ def quantize_model(
quantized_config["quantization_config"] = quantized_config["quantization"] quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters())) quantized_weights = dict(tree_flatten(model.parameters()))
bpw = compute_bits_per_weight(quantized_config, quantized_weights) bpw = compute_bits_per_weight(model)
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.") print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
return quantized_weights, quantized_config return quantized_weights, quantized_config