diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 835cb482..8351ed1b 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model
 
 
-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
+def nparams(module):
+    if hasattr(module, "bits"):
+        n = 0 if not hasattr(module, "bias") else module.bias.size
+        return n + module.weight.size * 32 // module.bits
+    return sum(v.size for _, v in tree_flatten(module.parameters()))
+
 
+def print_trainable_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 86b786ce..66a106a1 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -24,7 +24,7 @@ from .models import cache
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
-from .tuner.utils import load_adapters
+from .tuner.utils import load_adapters, nparams
 
 # Constants
 MODEL_REMAPPING = {
@@ -127,6 +127,17 @@ def _get_classes(config: dict):
     return arch.Model, arch.ModelArgs
 
 
+def compute_bits_per_weight(model):
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    model_params = sum(nparams(m) for _, m in leaf_modules)
+    return model_bytes * 8 / model_params
+
+
 def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
@@ -496,15 +507,20 @@ def load_model(
     weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
+
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
+            # Handle legacy models which may not have everything quantized
             return f"{p}.scales" in weights
 
         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
 
@@ -707,7 +723,13 @@ def save_weights(
 
 
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -717,17 +739,37 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.
 
     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
+
+    bpw = compute_bits_per_weight(model)
+    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
+
     return quantized_weights, quantized_config
 
 
@@ -764,6 +806,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -789,7 +834,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )
 
     if dequantize:
         print("[INFO] Dequantizing")
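
For reference, a usage sketch of the new `quant_predicate` hook, not part of the diff itself. It assumes the standard `mlx_lm.convert` entry point; the Hugging Face repo, output path, helper name `mixed_precision`, and the 4/8-bit split are illustrative placeholders.

```python
# Sketch: convert a model with a per-layer quantization recipe via quant_predicate.
import mlx.nn as nn

from mlx_lm import convert


def mixed_precision(path: str, module: nn.Module, config: dict):
    """Quantize most layers to 4 bits, but keep the output head at 8 bits."""
    if not hasattr(module, "to_quantized"):
        # Skip modules that cannot be quantized at all.
        return False
    if "lm_head" in path:
        # A dict is forwarded to `to_quantized` and recorded per-path in the config.
        return {"group_size": 64, "bits": 8}
    # True quantizes the layer with the default q_group_size / q_bits.
    return True


convert(
    "mistralai/Mistral-7B-Instruct-v0.3",  # placeholder HF repo
    mlx_path="mlx_model_mixed",            # placeholder output directory
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=mixed_precision,
)
```

Per the diff, whatever the predicate returns for a layer is stored under that layer's path in `quantized_config["quantization"]`, and the new `class_predicate` in `load_model` reads those per-path entries back so the mixed-precision model reloads with the same layout.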