diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 9bac77a5..38cee117 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -29,6 +29,12 @@ def configure_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--q-bits", help="Bits per weight for quantization.", type=int, default=4
     )
+    parser.add_argument(
+        "--q-type",
+        choices=["affine", "affine-packed"],
+        default="affine",
+        help="The type of quantization to apply.",
+    )
     parser.add_argument(
         "--dtype",
         help="Type to save the non-quantized parameters.",
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4d69115e..fcdc38c7 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -528,6 +528,7 @@ def load_model(
             model,
             group_size=quantization["group_size"],
             bits=quantization["bits"],
+            quantization_type=quantization["quantization_type"],
             class_predicate=class_predicate,
         )

@@ -737,6 +738,7 @@ def quantize_model(
     config: dict,
     q_group_size: int,
     q_bits: int,
+    q_type: str,
     quant_predicate: Optional[
         Callable[[str, nn.Module, dict], Union[bool, dict]]
     ] = None,
@@ -749,6 +751,7 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        q_type (str): The type of quantization to apply.
         quant_predicate (Callable): A callable that decides how
             to quantize each layer based on the path.
             Accepts the layer `path`, the `module` and the model `config`.
@@ -759,11 +762,25 @@ def quantize_model(
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    quantized_config["quantization"] = {
+        "group_size": q_group_size,
+        "bits": q_bits,
+        "quantization_type": q_type,
+    }

     # Add any custom quantization parameters to the config as we go
     def _class_predicate(p, m):
-        bool_or_params = quant_predicate(p, m, config)
+        if quant_predicate:
+            bool_or_params = quant_predicate(p, m, config)
+        else:
+            if isinstance(m, nn.Embedding):
+                bool_or_params = {
+                    "group_size": q_group_size,
+                    "bits": q_bits,
+                    "quantization_type": "affine",
+                }
+            else:
+                bool_or_params = hasattr(m, "to_quantized")
         quantized_config["quantization"][p] = bool_or_params
         return bool_or_params

@@ -771,7 +788,8 @@ def quantize_model(
         model,
         q_group_size,
         q_bits,
-        class_predicate=_class_predicate if quant_predicate else None,
+        quantization_type=q_type,
+        class_predicate=_class_predicate,
     )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
@@ -812,6 +830,7 @@ def convert(
     quantize: bool = False,
     q_group_size: int = 64,
     q_bits: int = 4,
+    q_type: str = "affine",
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
@@ -845,7 +864,12 @@ def convert(
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(
-            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+            model,
+            config,
+            q_group_size,
+            q_bits,
+            q_type=q_type,
+            quant_predicate=quant_predicate,
         )

     if dequantize:
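
A quick usage sketch for the new option. This is illustrative only: the Hugging Face repo id below is a placeholder, and the CLI flags other than `--q-type` are assumed to match the existing `configure_parser` setup.

```python
# Hypothetical example of driving the new q_type parameter from Python.
# Assumed-equivalent CLI invocation:
#   mlx_lm.convert --hf-path <repo> -q --q-bits 4 --q-type affine-packed
from mlx_lm import convert

convert(
    "some-org/some-model",   # placeholder HF repo id
    mlx_path="mlx_model",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    q_type="affine-packed",  # new parameter introduced by this patch
)
```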
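
With this change, the `quantization` section that `quantize_model` writes into the config (and mirrors into `quantization_config`) gains a `quantization_type` key, which `load_model` reads back and forwards to `nn.quantize`. A sketch of the expected shape; the values and layer paths are illustrative, not taken from the diff:

```python
# Illustrative shape of the stored quantization metadata after this change.
quantization = {
    "group_size": 64,
    "bits": 4,
    "quantization_type": "affine-packed",
    # Per-layer entries recorded by _class_predicate as it walks the model.
    # Embeddings are pinned to plain affine by the new default branch:
    "model.embed_tokens": {
        "group_size": 64,
        "bits": 4,
        "quantization_type": "affine",
    },
    # Modules without to_quantized are recorded as False:
    "model.norm": False,
}
```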
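
Because `_class_predicate` is now always passed to `nn.quantize`, a user-supplied `quant_predicate` keeps full control while its per-layer decisions still get recorded in the config. A minimal sketch of a custom predicate, assuming the signature documented above; the `lm_head` path is illustrative and model-specific:

```python
import mlx.nn as nn

def my_quant_predicate(path, module, config):
    """Example predicate: skip the output head, pin embeddings to affine."""
    if path == "lm_head":  # illustrative path, not from the diff
        return False       # leave this module in full precision
    if isinstance(module, nn.Embedding):
        # Dict form: these parameters are forwarded to the module's
        # to_quantized, matching the new default embedding handling.
        return {"group_size": 64, "bits": 4, "quantization_type": "affine"}
    # Everything else: quantize only if the module supports it.
    return hasattr(module, "to_quantized")
```

Passed as `convert(..., quant_predicate=my_quant_predicate)`, this takes precedence over the default branch added in `_class_predicate`.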