diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d4afd428..cee5676b 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -460,13 +460,17 @@ def load_model(
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
             return f"{p}.scales" in weights
 
         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
 
@@ -669,7 +673,13 @@ def save_weights(
 
 
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -679,13 +689,31 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.
 
     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        if isinstance(bool_or_params, dict):
+            quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
@@ -726,6 +754,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -751,7 +782,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )
 
     if dequantize:
         print("[INFO] Dequantizing")
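
Usage sketch (not part of the diff): a `quant_predicate` that keeps the embeddings
and output head at 8 bits while quantizing the rest of the model with the default
4-bit settings. The Hugging Face repo and the layer-path substrings
(`embed_tokens`, `lm_head`) are illustrative assumptions, not anything this diff
prescribes.

import mlx.nn as nn
from mlx_lm.utils import convert


def mixed_quant_predicate(path: str, module: nn.Module, config: dict):
    # Layers without `to_quantized` cannot be quantized at all.
    if not hasattr(module, "to_quantized"):
        return False
    # Returning a dict gives this layer its own quantization parameters;
    # quantize_model records them in the saved config under the layer path,
    # and load_model's class_predicate replays them at load time.
    if "embed_tokens" in path or "lm_head" in path:
        return {"group_size": 64, "bits": 8}
    # Returning True falls back to the top-level q_group_size / q_bits.
    return True


convert(
    "some-org/some-model",  # placeholder HF repo
    mlx_path="mlx_model",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=mixed_quant_predicate,
)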