mlx-examples/llms/mlx_lm/convert.py

132 lines
3.6 KiB
Python
Raw Normal View History

import argparse
import copy
import glob
import json
import shutil
from pathlib import Path
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
from .utils import (
fetch_from_hub,
get_model_path,
linear_class_predicate,
save_weights,
upload_to_hub,
)
def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
parser.add_argument(
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
)
parser.add_argument(
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
)
parser.add_argument(
"--q-group-size", help="Group size for quantization.", type=int, default=64
)
parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
return parser
def quantize_model(
model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
"""
Applies quantization to the model weights.
Args:
model (nn.Module): The model to be quantized.
config (dict): Model configuration.
q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization.
Returns:
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
nn.QuantizedLinear.quantize_module(
model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
quantize: bool = False,
q_group_size: int = 64,
q_bits: int = 4,
dtype: str = "float16",
upload_repo: str = None,
):
print("[INFO] Loading")
model_path = get_model_path(hf_path)
model, config, tokenizer = fetch_from_hub(model_path)
weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype)
2024-01-06 13:29:15 +08:00
weights = {k: v.astype(dtype) for k, v in weights.items()}
if quantize:
print("[INFO] Quantizing")
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits)
if isinstance(mlx_path, str):
mlx_path = Path(mlx_path)
save_weights(mlx_path, weights)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)
if __name__ == "__main__":
parser = configure_parser()
args = parser.parse_args()
convert(**vars(args))