# Copyright © 2023-2024 Apple Inc.

import argparse
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Union

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
    """Split a weight dict into shards of at most ``max_file_size_gb`` GiB.

    Args:
        weights: Mapping of parameter name to array (anything exposing
            ``.nbytes``).
        max_file_size_gb: Upper bound per shard, in gibibytes.

    Returns:
        A list of dicts whose union reproduces ``weights``. An empty
        ``weights`` yields an empty list. A single tensor larger than the
        limit gets its own shard (it cannot be split further).
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        # Close the current shard before this tensor would push it over the
        # limit. Guard on `shard` so we never emit an empty shard when the
        # very first tensor already exceeds the limit on its own.
        if shard and shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    # Flush the trailing shard only if it holds anything (it is empty when
    # `weights` itself was empty).
    if shard:
        shards.append(shard)
    return shards


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Write ``weights`` to ``save_path`` as safetensors plus an index file.

    The directory is created if needed. With multiple shards the Hugging
    Face naming scheme ``model-XXXXX-of-YYYYY.safetensors`` is used; a
    single shard is written as ``model.safetensors``. A
    ``model.safetensors.index.json`` mapping each weight name to its shard
    (with the total byte size in the metadata) is always written.
    """
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    total_shards = len(shards)
    if total_shards > 1:
        name_template = "model-{:05d}-of-{:05d}.safetensors"
    else:
        name_template = "model.safetensors"

    index_data: Dict[str, Any] = {
        "metadata": {"total_size": sum(v.nbytes for v in weights.values())},
        "weight_map": {},
    }

    for shard_number, shard in enumerate(shards, start=1):
        shard_name = name_template.format(shard_number, total_shards)
        mx.save_safetensors(str(save_path / shard_name), shard)
        for weight_name in shard:
            index_data["weight_map"][weight_name] = shard_name

    # Sort the weight map by name so the index file is deterministic.
    index_data["weight_map"] = dict(sorted(index_data["weight_map"].items()))

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=4)


def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
    """Resolve a local path or Hugging Face repo id to a local directory.

    If ``path_or_hf_repo`` exists on disk it is returned unchanged;
    otherwise the repository is fetched with ``snapshot_download``
    (restricted to weight, JSON and text files) and the local snapshot
    directory is returned.

    Args:
        path_or_hf_repo: Filesystem path or Hugging Face repository id.
        force_download: Re-download even if a cached snapshot exists.
    """
    model_path = Path(path_or_hf_repo)
    if model_path.exists():
        return model_path
    return Path(
        snapshot_download(
            repo_id=path_or_hf_repo,
            allow_patterns=["*.bin", "*.json", "*.txt"],
            force_download=force_download,
        )
    )


def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
    """Convert a PyTorch tensor to an MLX array of the requested dtype.

    ``dtype`` is the name of the target type (e.g. ``"float16"``,
    ``"bfloat16"``), resolved on both the torch and mx modules.
    """
    if dtype == "bfloat16":
        # numpy cannot represent bfloat16: stage through float32 so the
        # final cast to mx.bfloat16 loses no precision.
        staged = a.to(torch.float32)
    else:
        staged = a.to(getattr(torch, dtype))
    return mx.array(staged.numpy(), getattr(mx, dtype))


def main() -> None:
    """Download (OpenAI) CLIP weights from Hugging Face and convert to MLX.

    Parses command-line arguments, fetches the PyTorch checkpoint, converts
    every tensor to the requested dtype, saves sharded safetensors under
    ``--mlx-path`` and copies the tokenizer/processor config files alongside.
    """
    parser = argparse.ArgumentParser(
        description="Download and Convert (OpenAI) CLIP weights to MLX"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default="openai/clip-vit-base-patch32",
        help="Hugging Face repository name.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        help="The data type to save the converted model.",
        type=str,
        default="float32",
    )
    parser.add_argument(
        "-f",
        "--force-download",
        help="Force download the model from Hugging Face.",
        action="store_true",
    )
    args = parser.parse_args()

    torch_path = get_model_path(args.hf_repo, args.force_download)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    print("[INFO] Loading")
    # weights_only=True restricts the pickle payload to tensors, preventing
    # arbitrary code execution from an untrusted checkpoint.
    torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)

    print("[INFO] Converting")
    mlx_weights = {
        k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
    }

    print("[INFO] Saving")
    save_weights(mlx_path, mlx_weights)

    # Ship the tokenizer/processor configuration next to the weights so the
    # converted model directory is self-contained.
    for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
        shutil.copyfile(
            str(torch_path / f"{fn}"),
            str(mlx_path / f"{fn}"),
        )


if __name__ == "__main__":
    main()