mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Mixed Quantizations (#1132)
* saving/loading mixed quantizations * comment * add bits per weight * more concise bpw * count bias too
This commit is contained in:
parent
cd8cf28c39
commit
2211b27388
@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
def nparams(m):
|
||||
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
|
||||
return m.weight.size * (32 // m.bits)
|
||||
return sum(v.size for _, v in tree_flatten(m.parameters()))
|
||||
def nparams(module):
|
||||
if hasattr(module, "bits"):
|
||||
n = 0 if not hasattr(module, "bias") else module.bias.size
|
||||
return n + module.weight.size * 32 // module.bits
|
||||
return sum(v.size for _, v in tree_flatten(module.parameters()))
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
|
@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten, tree_reduce
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
@ -24,7 +24,7 @@ from .models import cache
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
from .tuner.utils import load_adapters
|
||||
from .tuner.utils import load_adapters, nparams
|
||||
|
||||
# Constants
|
||||
MODEL_REMAPPING = {
|
||||
@ -127,6 +127,17 @@ def _get_classes(config: dict):
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
def compute_bits_per_weight(model):
|
||||
model_bytes = tree_reduce(
|
||||
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
||||
)
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
model_params = sum(nparams(m) for _, m in leaf_modules)
|
||||
return model_bytes * 8 / model_params
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Ensures the model is available locally. If the path does not exist locally,
|
||||
@ -496,15 +507,20 @@ def load_model(
|
||||
weights = model.sanitize(weights)
|
||||
|
||||
if (quantization := config.get("quantization", None)) is not None:
|
||||
# Handle legacy models which may not have everything quantized
|
||||
|
||||
def class_predicate(p, m):
|
||||
# Handle custom per layer quantizations
|
||||
if p in config["quantization"]:
|
||||
return config["quantization"][p]
|
||||
if not hasattr(m, "to_quantized"):
|
||||
return False
|
||||
# Handle legacy models which may not have everything quantized
|
||||
return f"{p}.scales" in weights
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
**quantization,
|
||||
group_size=quantization["group_size"],
|
||||
bits=quantization["bits"],
|
||||
class_predicate=class_predicate,
|
||||
)
|
||||
|
||||
@ -707,7 +723,13 @@ def save_weights(
|
||||
|
||||
|
||||
def quantize_model(
|
||||
model: nn.Module, config: dict, q_group_size: int, q_bits: int
|
||||
model: nn.Module,
|
||||
config: dict,
|
||||
q_group_size: int,
|
||||
q_bits: int,
|
||||
quant_predicate: Optional[
|
||||
Callable[[str, nn.Module, dict], Union[bool, dict]]
|
||||
] = None,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Applies quantization to the model weights.
|
||||
@ -717,17 +739,37 @@ def quantize_model(
|
||||
config (dict): Model configuration.
|
||||
q_group_size (int): Group size for quantization.
|
||||
q_bits (int): Bits per weight for quantization.
|
||||
quant_predicate (Callable): A callable that decides how
|
||||
to quantize each layer based on the path.
|
||||
Accepts the layer `path`, the `module` and the model `config`.
|
||||
Returns either a bool to signify quantize/no quantize or
|
||||
a dict of quantization parameters to pass to `to_quantized`.
|
||||
|
||||
Returns:
|
||||
Tuple: Tuple containing quantized weights and config.
|
||||
"""
|
||||
quantized_config = copy.deepcopy(config)
|
||||
nn.quantize(model, q_group_size, q_bits)
|
||||
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
|
||||
|
||||
# Add any custom quantization parameters to the config as we go
|
||||
def _class_predicate(p, m):
|
||||
bool_or_params = quant_predicate(p, m, config)
|
||||
quantized_config["quantization"][p] = bool_or_params
|
||||
return bool_or_params
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
q_group_size,
|
||||
q_bits,
|
||||
class_predicate=_class_predicate if quant_predicate else None,
|
||||
)
|
||||
# support hf model tree #957
|
||||
quantized_config["quantization_config"] = quantized_config["quantization"]
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
bpw = compute_bits_per_weight(model)
|
||||
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
@ -764,6 +806,9 @@ def convert(
|
||||
upload_repo: str = None,
|
||||
revision: Optional[str] = None,
|
||||
dequantize: bool = False,
|
||||
quant_predicate: Optional[
|
||||
Callable[[str, nn.Module, dict], Union[bool, dict]]
|
||||
] = None,
|
||||
):
|
||||
# Check the save path is empty
|
||||
if isinstance(mlx_path, str):
|
||||
@ -789,7 +834,9 @@ def convert(
|
||||
if quantize:
|
||||
print("[INFO] Quantizing")
|
||||
model.load_weights(list(weights.items()))
|
||||
weights, config = quantize_model(model, config, q_group_size, q_bits)
|
||||
weights, config = quantize_model(
|
||||
model, config, q_group_size, q_bits, quant_predicate=quant_predicate
|
||||
)
|
||||
|
||||
if dequantize:
|
||||
print("[INFO] Dequantizing")
|
||||
|
Loading…
Reference in New Issue
Block a user