From 46109e4141610cbfa64185421272f0234cf8b79e Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Sun, 8 Dec 2024 12:44:29 -0800
Subject: [PATCH] more concise bpw

---
 llms/mlx_lm/tuner/utils.py | 11 ++++++-----
 llms/mlx_lm/utils.py       | 41 ++++++++++++++---------------------------
 2 files changed, 20 insertions(+), 32 deletions(-)

diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 7c78ee91..87f31f68 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -249,12 +249,13 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model
 
 
-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
+def nparams(module):
+    if isinstance(module, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
+        return module.weight.size * 32 // module.bits
+    return sum(v.size for _, v in tree_flatten(module.parameters()))
 
+
+def print_trainable_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 640c1008..3044570d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -23,7 +23,7 @@ from .models import cache
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
-from .tuner.utils import load_adapters
+from .tuner.utils import load_adapters, nparams
 
 # Constants
 MODEL_REMAPPING = {
@@ -100,6 +100,17 @@ def _get_classes(config: dict):
     return arch.Model, arch.ModelArgs
 
 
+def compute_bits_per_weight(model):
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    model_params = sum(nparams(m) for _, m in leaf_modules)
+    return model_bytes * 8 / model_params
+
+
 def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
@@ -404,30 +415,6 @@ def load_config(model_path: Path) -> dict:
     return config
 
 
-def compute_bits_per_weight(config, weights):
-    weight_count = 0
-    bit_count = 0
-    for name, param in weights.items():
-        bit_count += param.size * param.dtype.size * 8
-        if ".scales" in name or ".biases" in name:
-            continue
-        if param.dtype == mx.uint32:
-            base_name = name.replace(".weight", "")
-            quant_config = config["quantization"]
-            quant_layer_config = quant_config.get(base_name)
-            bits = (
-                quant_config["bits"]
-                if not quant_layer_config
-                else quant_layer_config["bits"]
-            )
-            weight_count += param.size * 32 / bits
-        else:
-            weight_count += param.size
-
-    bits_per_weight = bit_count / weight_count
-    return bits_per_weight
-
-
 def load_model(
     model_path: Path,
     lazy: bool = False,
@@ -742,7 +729,7 @@ def quantize_model(
     quantized_config["quantization_config"] = quantized_config["quantization"]
 
     quantized_weights = dict(tree_flatten(model.parameters()))
-    bpw = compute_bits_per_weight(quantized_config, quantized_weights)
+    bpw = compute_bits_per_weight(model)
     print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
 
     return quantized_weights, quantized_config