add dequantize option to mlx_lm/convert.py (#547)

Author: Alwin Arrasyid, 2024-03-20 09:50:08 +07:00 (committed by GitHub)
parent 6f2fd5daea
commit 6c3d4c8ba2
2 changed files with 17 additions and 0 deletions

mlx_lm/convert.py

@@ -42,6 +42,13 @@ def configure_parser() -> argparse.ArgumentParser:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-d",
+        "--dequantize",
+        help="Dequantize a quantized model.",
+        action="store_true",
+        default=False,
+    )
     return parser
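
The new flag makes dequantization available from the command line. A minimal usage sketch (the Hugging Face repo id below is illustrative, not part of this commit):

    python -m mlx_lm.convert \
        --hf-path mlx-community/Mistral-7B-Instruct-v0.2-4bit \
        --mlx-path dequantized_model \
        --dequantize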

mlx_lm/utils.py

@@ -19,6 +19,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
 from .tuner.utils import apply_lora_layers
+from .tuner.utils import dequantize as dequantize_model
 
 # Constants
 MODEL_REMAPPING = {
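
The helper imported here lives in mlx_lm/tuner/utils.py. As a rough sketch of what such a dequantizer does (illustrative code, not the exact implementation in this repo): walk the module tree, rebuild every nn.QuantizedLinear as a plain nn.Linear whose weight is recovered with mx.dequantize, then swap the rebuilt layers back into the model.

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_unflatten

    def dequantize(model: nn.Module) -> nn.Module:
        # Illustrative sketch; not the exact mlx_lm.tuner.utils code.
        rebuilt = []
        for name, module in model.named_modules():
            if isinstance(module, nn.QuantizedLinear):
                # Recover the full-precision weight from the packed
                # integer weight plus per-group scales and biases.
                weight = mx.dequantize(
                    module.weight,
                    module.scales,
                    module.biases,
                    module.group_size,
                    module.bits,
                ).astype(mx.float16)
                output_dims, input_dims = weight.shape
                linear = nn.Linear(input_dims, output_dims, bias="bias" in module)
                linear.weight = weight
                if "bias" in module:
                    linear.bias = module.bias
                rebuilt.append((name, linear))
        if rebuilt:
            # Swap the dequantized layers back into the module tree.
            model.update_modules(tree_unflatten(rebuilt))
        return model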
@@ -587,6 +588,7 @@ def convert(
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
+    dequantize: bool = False,
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path, revision=revision)
@@ -596,11 +598,19 @@ def convert(
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
 
+    if quantize and dequantize:
+        raise ValueError("Choose either quantize or dequantize, not both.")
+
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)
 
+    if dequantize:
+        print("[INFO] Dequantizing")
+        model = dequantize_model(model)
+        weights = dict(tree_flatten(model.parameters()))
+
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)
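
The new keyword argument also exposes dequantization through the Python API. A minimal usage sketch (the repo id is illustrative):

    from mlx_lm import convert

    # Expand a 4-bit community model back to float16 weights on disk.
    convert(
        hf_path="mlx-community/Mistral-7B-Instruct-v0.2-4bit",  # illustrative
        mlx_path="dequantized_model",
        dequantize=True,
    )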