Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Mixed Quantizations (#1132)

* saving/loading mixed quantizations
* comment
* add bits per weight
* more concise bpw
* count bias too

Parent: cd8cf28c39
Commit: 2211b27388
@@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model
 
 
-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
+def nparams(module):
+    if hasattr(module, "bits"):
+        n = 0 if not hasattr(module, "bias") else module.bias.size
+        return n + module.weight.size * 32 // module.bits
+    return sum(v.size for _, v in tree_flatten(module.parameters()))
+
 
+def print_trainable_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
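The new module-level `nparams` recovers the logical parameter count of a quantized layer from its packed storage: each element of the stored weight matrix encodes `32 // bits` of the original weights, and the bias, when present, is counted at face value. A minimal sketch of that arithmetic, assuming MLX's usual uint32 packing and the standard `nn.QuantizedLinear` constructor (layer sizes chosen only for illustration):

import mlx.nn as nn

layer = nn.QuantizedLinear(input_dims=512, output_dims=1024, bits=4, group_size=64)

packed_elements = layer.weight.size             # elements of the packed weight matrix
recovered = packed_elements * 32 // layer.bits  # original weight count, as in nparams
assert recovered == 512 * 1024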
@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -24,7 +24,7 @@ from .models import cache
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
-from .tuner.utils import load_adapters
+from .tuner.utils import load_adapters, nparams
 
 # Constants
 MODEL_REMAPPING = {
@@ -127,6 +127,17 @@ def _get_classes(config: dict):
     return arch.Model, arch.ModelArgs
 
 
+def compute_bits_per_weight(model):
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    model_params = sum(nparams(m) for _, m in leaf_modules)
+    return model_bytes * 8 / model_params
+
+
 def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
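`compute_bits_per_weight` divides the total byte size of every `mx.array` in the model by the recovered parameter count, so the reported figure includes the scales and biases added by quantization as well as any tensors left unquantized. As a rough expectation, assuming fp16 scales and biases per quantization group (an assumption about the storage format, not something this diff states), uniform 4-bit quantization with group size 64 comes out to about 4.5 bits per weight before counting unquantized layers:

# Back-of-the-envelope bpw for uniform quantization, under the fp16-per-group assumption.
bits, group_size = 4, 64
expected_bpw = bits + 2 * 16 / group_size  # scale + bias overhead per group
print(f"{expected_bpw:.2f} bits per weight")  # 4.50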
@@ -496,15 +507,20 @@ def load_model(
         weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
+
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
+            # Handle legacy models which may not have everything quantized
             return f"{p}.scales" in weights
 
         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
 
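With this change, the `quantization` section of a saved config can carry per-layer overrides next to the global settings, and the layer-path keys are exactly what the updated `class_predicate` looks up on reload. An illustrative fragment (the layer paths are hypothetical; the values are whatever the quantization predicate returned, either `False` to skip a layer or keyword arguments for `to_quantized`):

quantization = {
    "group_size": 64,
    "bits": 4,
    "lm_head": {"group_size": 64, "bits": 8},  # more bits for the output head
    "model.layers.0.mlp.gate": False,          # left unquantized
}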
@@ -707,7 +723,13 @@ def save_weights(
 
 
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -717,17 +739,37 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.
 
     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
+
+    bpw = compute_bits_per_weight(model)
+    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
 
     return quantized_weights, quantized_config
 
 
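The new `quant_predicate` hook is what makes mixed schemes reachable from user code: per the docstring above, it receives the layer path, the module, and the model config, and returns either a bool or a dict of parameters for `to_quantized`; `quantize_model` both applies the decision and records it in the saved config. A minimal sketch, with the path matching purely illustrative:

import mlx.nn as nn

def quant_predicate(path: str, module: nn.Module, config: dict):
    # Leave modules that cannot be quantized alone.
    if not hasattr(module, "to_quantized"):
        return False
    # Give the output projection extra precision (the path name is an assumption).
    if path.endswith("lm_head"):
        return {"group_size": 64, "bits": 8}
    # Everything else uses the global q_group_size / q_bits.
    return True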
@@ -764,6 +806,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -789,7 +834,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )
 
     if dequantize:
         print("[INFO] Dequantizing")
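Putting the pieces together, the predicate can be passed to `convert`, which forwards it to `quantize_model` and then prints the resulting bits per weight. A usage sketch, with the import path and the source repo assumed rather than taken from this diff:

from mlx_lm.utils import convert  # import path assumed

convert(
    "mistralai/Mistral-7B-Instruct-v0.3",  # hypothetical source model
    mlx_path="mlx_model",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=quant_predicate,  # e.g. the sketch above
)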