diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index 7eac34aa..5f2f3adf 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -42,6 +42,13 @@ def configure_parser() -> argparse.ArgumentParser: type=str, default=None, ) + parser.add_argument( + "-d", + "--dequantize", + help="Dequantize a quantized model.", + action="store_true", + default=False, + ) return parser diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5f2a2f8c..7b0e2da7 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -19,6 +19,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer # Local imports from .tuner.utils import apply_lora_layers +from .tuner.utils import dequantize as dequantize_model # Constants MODEL_REMAPPING = { @@ -587,6 +588,7 @@ def convert( dtype: str = "float16", upload_repo: str = None, revision: Optional[str] = None, + dequantize: bool = False, ): print("[INFO] Loading") model_path = get_model_path(hf_path, revision=revision) @@ -596,11 +598,19 @@ def convert( dtype = mx.float16 if quantize else getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} + if quantize and dequantize: + raise ValueError("Choose either quantize or dequantize, not both.") + if quantize: print("[INFO] Quantizing") model.load_weights(list(weights.items())) weights, config = quantize_model(model, config, q_group_size, q_bits) + if dequantize: + print("[INFO] Dequantizing") + model = dequantize_model(model) + weights = dict(tree_flatten(model.parameters())) + if isinstance(mlx_path, str): mlx_path = Path(mlx_path)