From 6c3d4c8ba2da9d9352ca1c3b478aa6712fe8ac27 Mon Sep 17 00:00:00 2001
From: Alwin Arrasyid
Date: Wed, 20 Mar 2024 09:50:08 +0700
Subject: [PATCH] add dequantize option to mlx_lm/convert.py (#547)

---
 llms/mlx_lm/convert.py |  7 +++++++
 llms/mlx_lm/utils.py   | 10 ++++++++++
 2 files changed, 17 insertions(+)

diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 7eac34aa..5f2f3adf 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -42,6 +42,13 @@ def configure_parser() -> argparse.ArgumentParser:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-d",
+        "--dequantize",
+        help="Dequantize a quantized model.",
+        action="store_true",
+        default=False,
+    )
     return parser
 
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 5f2a2f8c..7b0e2da7 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,6 +19,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
 from .tuner.utils import apply_lora_layers
+from .tuner.utils import dequantize as dequantize_model
 
 # Constants
 MODEL_REMAPPING = {
@@ -587,6 +588,7 @@ def convert(
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
+    dequantize: bool = False,
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path, revision=revision)
@@ -596,11 +598,19 @@ def convert(
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
 
+    if quantize and dequantize:
+        raise ValueError("Choose either quantize or dequantize, not both.")
+
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)
 
+    if dequantize:
+        print("[INFO] Dequantizing")
+        model = dequantize_model(model)
+        weights = dict(tree_flatten(model.parameters()))
+
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)
 