Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
add dequantize option to mlx_lm/convert.py (#547)
parent 6f2fd5daea · commit 6c3d4c8ba2
@@ -42,6 +42,13 @@ def configure_parser() -> argparse.ArgumentParser:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-d",
+        "--dequantize",
+        help="Dequantize a quantized model.",
+        action="store_true",
+        default=False,
+    )
     return parser
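With the flag registered, an already-quantized model can be converted back to float weights from the command line, e.g. `python -m mlx_lm.convert --hf-path <quantized-model> --dequantize`; note that the entry point and the `--hf-path` option are assumed from the surrounding `convert.py`, not shown in this diff.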
@@ -19,6 +19,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

 # Local imports
 from .tuner.utils import apply_lora_layers
+from .tuner.utils import dequantize as dequantize_model

 # Constants
 MODEL_REMAPPING = {
@@ -587,6 +588,7 @@ def convert(
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
+    dequantize: bool = False,
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path, revision=revision)
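The same option is exposed on the Python API. A minimal usage sketch, assuming `convert` lives in `mlx_lm.utils` as the hunk context suggests (the repo id and output path below are placeholders):

    from mlx_lm.utils import convert

    # Rewrite a quantized Hub model with dequantized float16 weights.
    convert(
        hf_path="mlx-community/some-4bit-model",  # placeholder: any quantized MLX model
        mlx_path="mlx_model_fp16",                # placeholder output directory
        dequantize=True,                          # the new option added in this commit
    )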
@@ -596,11 +598,19 @@ def convert(
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}

+    if quantize and dequantize:
+        raise ValueError("Choose either quantize or dequantize, not both.")
+
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)

+    if dequantize:
+        print("[INFO] Dequantizing")
+        model = dequantize_model(model)
+        weights = dict(tree_flatten(model.parameters()))
+
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)
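For reference, the `dequantize_model` helper imported from `.tuner.utils` has to rebuild full-precision weights from each quantized layer. A minimal sketch of such a helper, assuming MLX's `nn.QuantizedLinear` layout and `mx.dequantize`; this is an illustration of the technique, not necessarily the repo's exact implementation:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_unflatten

    def dequantize(model: nn.Module) -> nn.Module:
        # Collect a float replacement for every quantized linear layer.
        replacements = []
        for name, module in model.named_modules():
            if isinstance(module, nn.QuantizedLinear):
                # Expand the packed weights using the stored per-group
                # scales and biases.
                weight = mx.dequantize(
                    module.weight,
                    module.scales,
                    module.biases,
                    module.group_size,
                    module.bits,
                ).astype(mx.float16)
                out_dims, in_dims = weight.shape
                has_bias = "bias" in module
                linear = nn.Linear(in_dims, out_dims, bias=has_bias)
                linear.weight = weight
                if has_bias:
                    linear.bias = module.bias
                replacements.append((name, linear))
        if replacements:
            # Swap the rebuilt layers back into the module tree.
            model.update_modules(tree_unflatten(replacements))
        return model

Up to quantization error, each rebuilt `nn.Linear` reproduces the output of the `QuantizedLinear` it replaces, which is what lets `convert` then flatten `model.parameters()` into a plain float checkpoint via `tree_flatten`.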