Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24)
add dequantize option to mlx_lm/convert.py (#547)
Parent: 6f2fd5daea
Commit: 6c3d4c8ba2
mlx_lm/convert.py
@@ -42,6 +42,13 @@ def configure_parser() -> argparse.ArgumentParser:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-d",
+        "--dequantize",
+        help="Dequantize a quantized model.",
+        action="store_true",
+        default=False,
+    )
     return parser
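With this hunk, the converter gains a -d/--dequantize flag. Following the invocation style mlx_lm already documents for conversion, usage would look like `python -m mlx_lm.convert --hf-path <quantized-model> --mlx-path <output-dir> -d`, where both paths are placeholders rather than part of this commit.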
mlx_lm/utils.py
@@ -19,6 +19,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

 # Local imports
 from .tuner.utils import apply_lora_layers
+from .tuner.utils import dequantize as dequantize_model

 # Constants
 MODEL_REMAPPING = {
@@ -587,6 +588,7 @@ def convert(
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
+    dequantize: bool = False,
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path, revision=revision)
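The new keyword is also reachable from the Python API. A minimal usage sketch, assuming convert is importable at the package level as in the mlx_lm README, with placeholder model paths:

# Hedged usage sketch; the paths below are placeholders, not part of this commit.
from mlx_lm import convert

convert(
    hf_path="mlx-community/some-4bit-model",  # a quantized source model (placeholder)
    mlx_path="dequantized_model",             # output directory (placeholder)
    dequantize=True,                          # keyword added by this commit
)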
@@ -596,11 +598,19 @@ def convert(
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}

+    if quantize and dequantize:
+        raise ValueError("Choose either quantize or dequantize, not both.")
+
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)

+    if dequantize:
+        print("[INFO] Dequantizing")
+        model = dequantize_model(model)
+        weights = dict(tree_flatten(model.parameters()))
+
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)
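The dequantize helper imported above from .tuner.utils is not part of this diff. As a minimal sketch of how such a helper can work, assuming it walks the module tree and swaps each nn.QuantizedLinear for an equivalent float nn.Linear via mx.dequantize (an illustration, not the verified body of tuner.utils.dequantize):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

def dequantize(model: nn.Module) -> nn.Module:
    # Collect a (path, replacement) pair for every quantized linear layer.
    replacements = []
    for name, module in model.named_modules():
        if isinstance(module, nn.QuantizedLinear):
            # mx.dequantize reverses mx.quantize given the stored scales/biases
            # and the layer's group size and bit width.
            weight = mx.dequantize(
                module.weight,
                module.scales,
                module.biases,
                module.group_size,
                module.bits,
            ).astype(mx.float16)
            out_dims, in_dims = weight.shape
            linear = nn.Linear(in_dims, out_dims, bias="bias" in module)
            linear.weight = weight
            if "bias" in module:
                linear.bias = module.bias
            replacements.append((name, linear))
    if replacements:
        # Splice the float layers back into the model in place.
        model.update_modules(tree_unflatten(replacements))
    return model

After this runs, tree_flatten(model.parameters()) in the hunk above yields plain float arrays, which is what the new dequantize branch saves.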