mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
chore(clip): update the clip example to make it compatible with HF format (#472)
* chore(clip): update the clip model to be HF format * Update clip/convert.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * chore: address comments * chore: rename ClipVisionModel and ClipTextModel * chore: add output hidden_states support * chore: remove custom conv2d and apply weight transpose during weight sanitizing * Update clip/model.py * Update clip/model.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
107
clip/convert.py
107
clip/convert.py
@@ -1,15 +1,68 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
||||
"""Save model weights into specified directory."""
|
||||
if isinstance(save_path, str):
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shards = make_shards(weights)
|
||||
shards_count = len(shards)
|
||||
shard_file_format = (
|
||||
"model-{:05d}-of-{:05d}.safetensors"
|
||||
if shards_count > 1
|
||||
else "model.safetensors"
|
||||
)
|
||||
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||
shard_path = save_path / shard_name
|
||||
|
||||
mx.save_safetensors(str(shard_path), shard)
|
||||
|
||||
for weight_name in shard.keys():
|
||||
index_data["weight_map"][weight_name] = shard_name
|
||||
|
||||
index_data["weight_map"] = {
|
||||
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||
}
|
||||
|
||||
with open(save_path / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(
|
||||
index_data,
|
||||
f,
|
||||
indent=4,
|
||||
)
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
@@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
|
||||
return mx.array(a.numpy(), getattr(mx, dtype))
|
||||
|
||||
|
||||
def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
|
||||
key = key.replace("embeddings.", "")
|
||||
key = key.replace("encoder.", "")
|
||||
key = key.replace("position_embedding.weight", "position_embedding")
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
if "layer_norm1." in key:
|
||||
key = key.replace("layer_norm1.", "ln1.")
|
||||
if "layer_norm2." in key:
|
||||
key = key.replace("layer_norm2.", "ln2.")
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
# Fix layernorm typo
|
||||
if "pre_layrnorm" in key:
|
||||
# Fix typo in weights :)
|
||||
key = key.replace("pre_layrnorm", "pre_layernorm")
|
||||
if "patch_embedding.weight" in key:
|
||||
# Initially, value: [out_channels, in_channels, kH, KW].
|
||||
# We want [out_channels, kH, KW, in_channels]
|
||||
value = value.permute(0, 2, 3, 1)
|
||||
return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))
|
||||
|
||||
|
||||
def should_keep_weight(key: str):
|
||||
return not ("position_ids" in key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download and Convert (OpenAI) CLIP weights to MLX"
|
||||
@@ -86,7 +101,12 @@ if __name__ == "__main__":
|
||||
default="mlx_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="The data type to save the converted model.",
|
||||
type=str,
|
||||
default="float32",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch_path = get_model_path(args.hf_repo)
|
||||
@@ -96,10 +116,11 @@ if __name__ == "__main__":
|
||||
print("[INFO] Loading")
|
||||
torch_weights = torch.load(torch_path / "pytorch_model.bin")
|
||||
print("[INFO] Converting")
|
||||
mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
|
||||
mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
|
||||
mlx_weights = {
|
||||
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
|
||||
}
|
||||
print("[INFO] Saving")
|
||||
mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
|
||||
save_weights(mlx_path, mlx_weights)
|
||||
for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
|
||||
shutil.copyfile(
|
||||
str(torch_path / f"{fn}"),
|
||||
|
||||
Reference in New Issue
Block a user