chore(clip): update the clip example to make it compatible with HF format (#472)

* chore(clip): update the clip model to be compatible with the HF format

* Update clip/convert.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: address comments

* chore: rename ClipVisionModel and ClipTextModel

* chore: add output hidden_states support

* chore: remove custom conv2d and apply weight transpose during weight sanitizing (sketched below)

* Update clip/model.py

* Update clip/model.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Anchen
2024-02-24 01:49:53 +11:00
committed by GitHub
parent f24edfa9dc
commit 47dd6bd17f
4 changed files with 267 additions and 104 deletions
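The weight-sanitizing item called out in the commit message comes down to reordering the patch-embedding convolution kernel when loading a checkpoint: PyTorch stores Conv2d weights as [out_channels, in_channels, kH, kW], while mlx.nn.Conv2d expects [out_channels, kH, kW, in_channels] (the same permute the old map_weights in the diff applied at conversion time). A minimal sketch, assuming a standalone sanitize helper rather than the exact method added to clip/model.py:

def sanitize(weights: dict) -> dict:
    # Illustrative helper only; in the commit the equivalent logic lives in the model code.
    sanitized = {}
    for k, v in weights.items():
        if "patch_embedding.weight" in k:
            # PyTorch conv layout [O, I, kH, kW] -> MLX conv layout [O, kH, kW, I]
            v = v.transpose(0, 2, 3, 1)
        sanitized[k] = v
    return sanitized

With the transpose handled while sanitizing, the model can use the stock mlx.nn.Conv2d for the patch embedding instead of a custom conv wrapper.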

clip/convert.py

@@ -1,15 +1,68 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import shutil
from pathlib import Path
from typing import Tuple
from typing import Any, Dict, Union
import mlx.core as mx
import torch
from huggingface_hub import snapshot_download

def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards
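# make_shards greedily packs tensors into dicts of at most roughly
# max_file_size_gb each (5 GiB by default, hence the << 30); a new shard is
# started whenever the next tensor would push the running size past the cap,
# so e.g. ~12 GiB of weights end up split across three or more shards.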

def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    for i, shard in enumerate(shards):
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name
        mx.save_safetensors(str(shard_path), shard)
        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )
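# The on-disk layout mirrors a Hugging Face checkpoint: one model.safetensors
# file (or model-XXXXX-of-XXXXX.safetensors shards) plus a
# model.safetensors.index.json whose weight_map ties every parameter name to
# the shard that stores it, with the total size in bytes recorded in metadata.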

def get_model_path(path_or_hf_repo: str) -> Path:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
@@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
    return mx.array(a.numpy(), getattr(mx, dtype))


def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
    key = key.replace("embeddings.", "")
    key = key.replace("encoder.", "")
    key = key.replace("position_embedding.weight", "position_embedding")

    # Map attention layers
    if "self_attn." in key:
        key = key.replace("self_attn.", "attention.")
    if "q_proj." in key:
        key = key.replace("q_proj.", "query_proj.")
    if "k_proj." in key:
        key = key.replace("k_proj.", "key_proj.")
    if "v_proj." in key:
        key = key.replace("v_proj.", "value_proj.")
    if "layer_norm1." in key:
        key = key.replace("layer_norm1.", "ln1.")
    if "layer_norm2." in key:
        key = key.replace("layer_norm2.", "ln2.")

    # Map ffn layers
    if "mlp.fc1" in key:
        key = key.replace("mlp.fc1", "linear1")
    if "mlp.fc2" in key:
        key = key.replace("mlp.fc2", "linear2")

    # Fix layernorm typo
    if "pre_layrnorm" in key:
        # Fix typo in weights :)
        key = key.replace("pre_layrnorm", "pre_layernorm")

    if "patch_embedding.weight" in key:
        # Initially, value: [out_channels, in_channels, kH, kW].
        # We want [out_channels, kH, kW, in_channels]
        value = value.permute(0, 2, 3, 1)

    return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))


def should_keep_weight(key: str):
    return not ("position_ids" in key)
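# map_weights and should_keep_weight belong to the old conversion path: keys
# were renamed to the example's custom module names and position_ids buffers
# were dropped. With the MLX model now following the HF naming scheme (per the
# commit message), checkpoint keys are kept as-is and these helpers go away.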

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download and Convert (OpenAI) CLIP weights to MLX"
@@ -86,7 +101,12 @@ if __name__ == "__main__":
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"--dtype",
help="The data type to save the converted model.",
type=str,
default="float32",
)
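    # Example invocation (flag spelling inferred from the args.* accesses below;
    # the repo name is only an illustration):
    #   python convert.py --hf-repo openai/clip-vit-base-patch32 --dtype float16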
    args = parser.parse_args()

    torch_path = get_model_path(args.hf_repo)
@@ -96,10 +116,11 @@ if __name__ == "__main__":
print("[INFO] Loading")
torch_weights = torch.load(torch_path / "pytorch_model.bin")
print("[INFO] Converting")
mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
mlx_weights = {
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
}
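    # Every tensor is cast to the requested dtype; the checkpoint keys are left
    # untouched so they line up with the HF-style parameter names in model.py.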
print("[INFO] Saving")
mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
save_weights(mlx_path, mlx_weights)
for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
shutil.copyfile(
str(torch_path / f"{fn}"),