Mixed Quantizations (#1132)

* saving/loading mixed quantizations

* comment

* add bits per weight

* more concise bpw

* count bias too
Alex Barron 2024-12-08 14:21:50 -08:00 committed by GitHub
parent cd8cf28c39
commit 2211b27388
2 changed files with 61 additions and 12 deletions

llms/mlx_lm/tuner/utils.py

@@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model


-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
+def nparams(module):
+    if hasattr(module, "bits"):
+        n = 0 if not hasattr(module, "bias") else module.bias.size
+        return n + module.weight.size * 32 // module.bits
+    return sum(v.size for _, v in tree_flatten(module.parameters()))


+def print_trainable_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
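A note on the reworked nparams: any module carrying a `bits` attribute is treated as quantized, so the logical parameter count is recovered from the packed storage (each stored 32-bit word holds `32 // bits` weight elements), and the unquantized bias, when present, is counted directly. A minimal sketch of that arithmetic, where the 4096x4096 shape and 4-bit width are assumptions chosen for illustration:

# Illustration only: shape and bit width below are assumed, not from the commit.
bits = 4
logical_weights = 4096 * 4096
packed_size = logical_weights * bits // 32   # what module.weight.size reports for a packed layer
recovered = packed_size * 32 // bits         # nparams' reconstruction of the logical count
bias_size = 4096                             # bias stays unquantized and is counted as-is
assert recovered == logical_weights
total = recovered + bias_size                # 16,781,312 parameters for this layer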

llms/mlx_lm/utils.py

@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 from transformers import PreTrainedTokenizer

 # Local imports
@@ -24,7 +24,7 @@ from .models import cache
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
-from .tuner.utils import load_adapters
+from .tuner.utils import load_adapters, nparams

 # Constants
 MODEL_REMAPPING = {
@@ -127,6 +127,17 @@ def _get_classes(config: dict):
     return arch.Model, arch.ModelArgs


+def compute_bits_per_weight(model):
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    model_params = sum(nparams(m) for _, m in leaf_modules)
+    return model_bytes * 8 / model_params
+
+
 def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
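compute_bits_per_weight divides the bytes actually held in the model's arrays by the logical parameter count from nparams. As a rough sanity check of the number it reports, here is the per-group arithmetic, assuming 4-bit affine quantization with group_size 64 and one float16 scale plus one float16 bias per group (unquantized parameters such as norms will pull the reported figure somewhat higher):

# Back-of-the-envelope bpw for one quantized group; the 4-bit / group-size-64
# configuration and float16 scale/bias storage are assumptions for illustration.
bits, group_size = 4, 64
packed_bits = group_size * bits        # 256 bits of packed weight values
scale_and_bias_bits = 16 + 16          # per-group fp16 scale + fp16 bias
bpw = (packed_bits + scale_and_bias_bits) / group_size
print(bpw)                             # 4.5 bits per weight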
@@ -496,15 +507,20 @@ def load_model(
         weights = model.sanitize(weights)

     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
+
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
+            # Handle legacy models which may not have everything quantized
             return f"{p}.scales" in weights

         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
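With this change, the saved "quantization" section of the config can carry per-path entries next to the global group_size/bits, and class_predicate simply returns whatever the config stored for that path. A sketch of what such a config might look like; the layer paths are illustrative, nothing in the commit hard-codes them:

# Hypothetical "quantization" section of config.json for a mixed model.
config = {
    "quantization": {
        "group_size": 64,
        "bits": 4,
        "lm_head": {"group_size": 32, "bits": 8},   # per-layer override
        "model.embed_tokens": False,                # left unquantized
    }
}
# class_predicate("lm_head", m)            -> {"group_size": 32, "bits": 8}
# class_predicate("model.embed_tokens", m) -> False
# any other path falls through to the legacy f"{p}.scales" in weights check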
@@ -707,7 +723,13 @@ def save_weights(


 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -717,17 +739,37 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.

     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
+
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))

+    bpw = compute_bits_per_weight(model)
+    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
+
     return quantized_weights, quantized_config
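For callers, the new quant_predicate hook is what enables mixed precision in a single pass. A sketch under assumed names (the "lm_head" path and the 8-bit/group-size-32 override are illustrative, not defaults of the commit):

# Hypothetical predicate: 8-bit lm_head, the 4-bit defaults everywhere else.
def mixed_quant_predicate(path, module, config):
    if not hasattr(module, "to_quantized"):
        return False
    if path == "lm_head":
        return {"group_size": 32, "bits": 8}
    return True

weights, config = quantize_model(
    model, config, q_group_size=64, q_bits=4, quant_predicate=mixed_quant_predicate
)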
@@ -764,6 +806,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -789,7 +834,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )

     if dequantize:
         print("[INFO] Dequantizing")
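End to end, the predicate can now be threaded through convert. A hedged example reusing the predicate sketched above; the Hugging Face repo id and output path are placeholders:

convert(
    "some-org/some-model",              # placeholder HF repo id
    mlx_path="mlx_model_mixed_4bit",    # placeholder output directory
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=mixed_quant_predicate,
)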