mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
210 lines
6.0 KiB
Python
210 lines
6.0 KiB
Python
import argparse
|
|
import copy
|
|
import glob
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, Tuple
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import transformers
|
|
from mlx.utils import tree_flatten
|
|
|
|
from .utils import get_model_path, linear_class_predicate, load
|
|
|
|
# Default upper bound, in GiB, on the estimated size of one weight shard;
# used as the default of make_shards(max_file_size_gb=...).
MAX_FILE_SIZE_GB = 15
|
|
|
|
|
|
def configure_parser() -> argparse.ArgumentParser:
    """
    Builds the command-line interface of the conversion script.

    The options are declared as a table of ``(flags, keyword arguments)``
    pairs and registered in order, so the help output lists them in the
    same order as the table.

    Returns:
        argparse.ArgumentParser: Parser exposing --hf-path, --mlx-path,
            -q/--quantize, --q-group-size, --q-bits, --dtype and
            --upload-repo.
    """
    option_table = (
        (
            ("--hf-path",),
            {"type": str, "help": "Path to the Hugging Face model."},
        ),
        (
            ("--mlx-path",),
            {
                "type": str,
                "default": "mlx_model",
                "help": "Path to save the MLX model.",
            },
        ),
        (
            ("-q", "--quantize"),
            {"help": "Generate a quantized model.", "action": "store_true"},
        ),
        (
            ("--q-group-size",),
            {"help": "Group size for quantization.", "type": int, "default": 64},
        ),
        (
            ("--q-bits",),
            {"help": "Bits per weight for quantization.", "type": int, "default": 4},
        ),
        (
            ("--dtype",),
            {
                "help": "Type to save the parameters, ignored if -q is given.",
                "type": str,
                "choices": ["float16", "bfloat16", "float32"],
                "default": "float16",
            },
        ),
        (
            ("--upload-repo",),
            {
                "help": "The Hugging Face repo to upload the model to.",
                "type": str,
                "default": None,
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli
|
|
|
|
|
|
def fetch_from_hub(
    model_path: str,
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
    """
    Loads the weights, configuration and tokenizer of a model.

    Args:
        model_path (str): Model location; it is passed through
            get_model_path before anything is read from it.

    Returns:
        tuple: ``(weights, config, tokenizer)`` where ``weights`` merges the
            entries of every ``*.safetensors`` file found, ``config`` is the
            AutoConfig converted with ``to_dict()`` and ``tokenizer`` comes
            from AutoTokenizer.

    Raises:
        FileNotFoundError: If the location holds no ``*.safetensors`` file.
    """
    resolved_path = get_model_path(model_path)

    safetensor_files = glob.glob(f"{resolved_path}/*.safetensors")
    if len(safetensor_files) == 0:
        raise FileNotFoundError(f"No safetensors found in {resolved_path}")

    # Merge all files into one flat mapping; a name seen again in a later
    # file replaces the earlier entry.
    weights = {}
    for safetensor_file in safetensor_files:
        for name, array in mx.load(safetensor_file).items():
            weights[name] = array

    hf_config = transformers.AutoConfig.from_pretrained(resolved_path)
    tokenizer = transformers.AutoTokenizer.from_pretrained(resolved_path)

    return weights, hf_config.to_dict(), tokenizer
|
|
|
|
|
|
def quantize_model(
    weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
) -> tuple:
    """
    Quantizes a model and returns its new weights and configuration.

    The model is instantiated with ``load(hf_path)``, populated with
    ``weights`` and then handed to ``nn.QuantizedLinear.quantize_module``.
    The input ``config`` is not modified; a deep copy receives the
    ``"quantization"`` entry.

    Args:
        weights (dict): Model weights.
        config (dict): Model configuration.
        hf_path (str): HF model path.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        tuple: ``(quantized_weights, quantized_config)``.
    """
    quantized_config = copy.deepcopy(config)

    model, _ = load(hf_path)
    weight_pairs = list(weights.items())
    model.load_weights(weight_pairs)

    nn.QuantizedLinear.quantize_module(
        model,
        q_group_size,
        q_bits,
        linear_class_predicate=linear_class_predicate,
    )

    # Record the settings in the copied config under "quantization".
    quantization_settings = {"group_size": q_group_size, "bits": q_bits}
    quantized_config["quantization"] = quantization_settings

    flat_parameters = tree_flatten(model.parameters())
    return dict(flat_parameters), quantized_config
|
|
|
|
|
|
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Weights are packed greedily in iteration order: a new shard is started
    whenever adding the next tensor would push the current shard past the
    limit. A single tensor larger than the limit is placed in a shard of
    its own rather than being split.

    Args:
        weights (dict): Model weights; each value must expose ``size`` and
            ``dtype.size``.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards (dicts). For empty ``weights`` the list
            holds one empty shard.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        estimated_size = v.size * v.dtype.size
        # Flush only a non-empty shard: without the `shard` check, an
        # oversized tensor arriving on an empty shard would emit an empty
        # shard, which is later saved as an empty weights file.
        if shard and shard_size + estimated_size > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += estimated_size
    shards.append(shard)
    return shards
|
|
|
|
|
|
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    The model card of ``hf_path`` is loaded, tagged with "mlx", given a new
    body with usage instructions and saved as ``README.md`` inside ``path``.
    The repo is then created if it does not exist and the whole folder is
    uploaded to it.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    # The card body is written flush-left so that no source indentation
    # ends up in the Markdown. Both references to the source model link to
    # its hub page (an empty link target renders as a dead link).
    card.text = f"""
# {upload_repo}
This model was converted to MLX format from [`{hf_path}`](https://huggingface.co/{hf_path}).
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx

```bash
pip install mlx-lm
```

```python
from mlx_lm import load, generate

model, tokenizer = load("{upload_repo}")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
"""
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    # exist_ok=True: uploading to a repo that already exists is not an error.
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
    )
|
|
|
|
|
|
def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: str = None,
):
    """
    Converts a Hugging Face model to MLX format and saves it.

    The output directory receives the sharded ``weights.NN.safetensors``
    files, the tokenizer files and ``config.json``.

    Args:
        hf_path (str): Path to the Hugging Face model.
        mlx_path (str): Directory to save the MLX model in; created if
            missing.
        quantize (bool): Whether to quantize the model.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        dtype (str): Name of the ``mlx.core`` dtype to save the parameters
            in; ignored when ``quantize`` is set.
        upload_repo (str): Hugging Face repo to upload the result to, or
            None to skip the upload.
    """
    print("[INFO] Loading")
    weights, config, tokenizer = fetch_from_hub(hf_path)

    # Quantization always starts from float16 weights, whatever `dtype` says.
    if quantize:
        target_dtype = mx.float16
    else:
        target_dtype = getattr(mx, dtype)
    weights = {name: array.astype(target_dtype) for name, array in weights.items()}

    if quantize:
        print("[INFO] Quantizing")
        weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)

    output_dir = Path(mlx_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    for index, shard in enumerate(make_shards(weights)):
        shard_file = output_dir / f"weights.{index:02d}.safetensors"
        mx.save_safetensors(str(shard_file), shard)

    tokenizer.save_pretrained(output_dir)

    with open(output_dir / "config.json", "w") as config_file:
        json.dump(config, config_file, indent=4)

    if upload_repo is None:
        return
    upload_to_hub(output_dir, upload_repo, hf_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: each parsed command-line option is forwarded to
    # convert() as the keyword argument of the same name.
    cli_args = configure_parser().parse_args()
    convert(**vars(cli_args))
|