2024-01-04 07:13:26 +08:00
|
|
|
import argparse
|
|
|
|
import copy
|
|
|
|
import glob
|
|
|
|
import json
|
2024-01-24 00:44:37 +08:00
|
|
|
import shutil
|
2024-01-04 07:13:26 +08:00
|
|
|
from pathlib import Path
|
2024-01-24 00:44:37 +08:00
|
|
|
from typing import Tuple
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
from mlx.utils import tree_flatten
|
2024-01-13 02:25:56 +08:00
|
|
|
|
2024-01-24 00:44:37 +08:00
|
|
|
from .utils import (
|
|
|
|
fetch_from_hub,
|
|
|
|
get_model_path,
|
|
|
|
linear_class_predicate,
|
|
|
|
save_weights,
|
|
|
|
upload_to_hub,
|
|
|
|
)
|
2024-01-12 04:29:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    ``--hf-path`` is required: without it ``hf_path`` would reach ``convert``
    as ``None`` and fail later during model lookup, instead of producing a
    clear usage error at parse time.

    The destinations produced by ``parse_args`` (``hf_path``, ``mlx_path``,
    ``quantize``, ``q_group_size``, ``q_bits``, ``dtype``, ``upload_repo``)
    match the keyword parameters of ``convert``, so the parsed namespace can be
    passed straight through with ``convert(**vars(args))``.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )

    parser.add_argument(
        "--hf-path",
        type=str,
        required=True,
        help="Path to the Hugging Face model.",
    )
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters, ignored if -q is given.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
|
2024-01-13 02:25:56 +08:00
|
|
|
def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
    """
    Quantize a model's weights and produce the matching configuration.

    The model is quantized in place through
    ``nn.QuantizedLinear.quantize_module``; which layers are affected is
    decided by ``linear_class_predicate``. The caller's ``config`` is not
    modified: a deep copy is made and a ``"quantization"`` entry recording the
    group size and bit width is set on that copy.

    Args:
        model (nn.Module): The model to be quantized (mutated in place).
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        Tuple: ``(weights, config)``, where ``weights`` is a flat dict of the
        quantized model's parameters and ``config`` is the updated copy.
    """
    # Copy first so the caller's configuration is never touched.
    result_config = copy.deepcopy(config)

    nn.QuantizedLinear.quantize_module(
        model,
        q_group_size,
        q_bits,
        linear_class_predicate=linear_class_predicate,
    )
    result_config["quantization"] = dict(group_size=q_group_size, bits=q_bits)

    flat_weights = dict(tree_flatten(model.parameters()))
    return flat_weights, result_config
|
|
|
|
|
|
|
|
|
2024-01-13 02:25:56 +08:00
|
|
|
def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: str = None,
):
    """
    Convert a Hugging Face model to MLX format and write it to ``mlx_path``.

    The output directory receives the converted weights (via ``save_weights``),
    any ``*.py`` files found next to the source model, the tokenizer files, and
    a ``config.json``.

    Args:
        hf_path (str): Source model location, resolved by ``get_model_path``.
        mlx_path (str): Output directory (a ``str`` or ``Path``).
        quantize (bool): If True, quantize the weights with ``quantize_model``.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        dtype (str): Name of an ``mlx.core`` dtype to cast the weights to.
            Ignored when ``quantize`` is True; float16 is used in that case.
        upload_repo (str, optional): If given, the output directory is uploaded
            to this Hugging Face repo with ``upload_to_hub``.
    """
    print("[INFO] Loading")
    model_path = get_model_path(hf_path)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    # Quantization always starts from float16 weights; ``dtype`` only applies
    # to the unquantized output.
    dtype = mx.float16 if quantize else getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize:
        print("[INFO] Quantizing")
        # Load the cast weights back into the model so quantization sees them.
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)
    # Create the output directory up front rather than relying on the helpers
    # below. If it did not exist, ``shutil.copy`` would write a *file* named
    # ``mlx_path`` instead of copying into a directory.
    mlx_path.mkdir(parents=True, exist_ok=True)

    # The model object is no longer needed once ``weights`` holds everything.
    del model
    save_weights(mlx_path, weights, donate_weights=True)

    # Copy any Python sources shipped with the source model. The directory part
    # is escaped so characters such as "[", "*" or "?" in the path are matched
    # literally instead of being interpreted as glob syntax.
    py_files = glob.glob(str(Path(glob.escape(str(model_path))) / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    with open(mlx_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the CLI and hand every option to convert() as a keyword argument.
    cli_args = configure_parser().parse_args()
    convert(**vars(cli_args))
|