chore(clip): update the clip example to make it compatible with HF format (#472)

* chore(clip): update the clip model to be in HF format

* Update clip/convert.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: address comments

* chore: rename ClipVisionModel and ClipTextModel

* chore: add output hidden_states support

* chore: remove custom conv2d and apply weight transpose during weight sanitizing

* Update clip/model.py

* Update clip/model.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Anchen 2024-02-24 01:49:53 +11:00 committed by GitHub
parent f24edfa9dc
commit 47dd6bd17f
4 changed files with 267 additions and 104 deletions

clip/convert.py

@@ -1,15 +1,68 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import shutil
from pathlib import Path
from typing import Tuple
from typing import Any, Dict, Union
import mlx.core as mx
import torch
from huggingface_hub import snapshot_download
def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += v.nbytes
shards.append(shard)
return shards
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
"""Save model weights into specified directory."""
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
shards_count = len(shards)
shard_file_format = (
"model-{:05d}-of-{:05d}.safetensors"
if shards_count > 1
else "model.safetensors"
)
total_size = sum(v.nbytes for v in weights.values())
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
for i, shard in enumerate(shards):
shard_name = shard_file_format.format(i + 1, shards_count)
shard_path = save_path / shard_name
mx.save_safetensors(str(shard_path), shard)
for weight_name in shard.keys():
index_data["weight_map"][weight_name] = shard_name
index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}
with open(save_path / "model.safetensors.index.json", "w") as f:
json.dump(
index_data,
f,
indent=4,
)
def get_model_path(path_or_hf_repo: str) -> Path:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
@@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
return mx.array(a.numpy(), getattr(mx, dtype))
def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
key = key.replace("embeddings.", "")
key = key.replace("encoder.", "")
key = key.replace("position_embedding.weight", "position_embedding")
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
if "layer_norm1." in key:
key = key.replace("layer_norm1.", "ln1.")
if "layer_norm2." in key:
key = key.replace("layer_norm2.", "ln2.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
# Fix layernorm typo
if "pre_layrnorm" in key:
# Fix typo in weights :)
key = key.replace("pre_layrnorm", "pre_layernorm")
if "patch_embedding.weight" in key:
# Initially, value: [out_channels, in_channels, kH, kW].
# We want [out_channels, kH, kW, in_channels]
value = value.permute(0, 2, 3, 1)
return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))
def should_keep_weight(key: str):
return not ("position_ids" in key)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download and Convert (OpenAI) CLIP weights to MLX"
@@ -86,7 +101,12 @@ if __name__ == "__main__":
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"--dtype",
help="The data type to save the converted model.",
type=str,
default="float32",
)
args = parser.parse_args()
torch_path = get_model_path(args.hf_repo)
@@ -96,10 +116,11 @@ if __name__ == "__main__":
print("[INFO] Loading")
torch_weights = torch.load(torch_path / "pytorch_model.bin")
print("[INFO] Converting")
mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
mlx_weights = {
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
}
print("[INFO] Saving")
mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
save_weights(mlx_path, mlx_weights)
for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
shutil.copyfile(
str(torch_path / f"{fn}"),
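
Note: the sharded safetensors layout that save_weights writes can be read back with MLX's standard I/O. Below is a minimal round-trip sketch; the directory and tensor names are made up for illustration, and the loading half mirrors the glob-and-merge logic added to model.py later in this diff.

```python
import glob
from pathlib import Path

import mlx.core as mx

# Save a tiny weight dict as a single safetensors shard (save_weights above
# adds sharding and an index.json on top of this call).
out_dir = Path("mlx_model_demo")
out_dir.mkdir(parents=True, exist_ok=True)
weights = {"layer.weight": mx.zeros((4, 4)), "layer.bias": mx.zeros((4,))}
mx.save_safetensors(str(out_dir / "model.safetensors"), weights)

# Read it back by globbing the shards and merging them into a single dict.
loaded = {}
for shard in glob.glob(str(out_dir / "*.safetensors")):
    loaded.update(mx.load(shard))
print(sorted(loaded))  # ['layer.bias', 'layer.weight']
```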

clip/model.py

@@ -1,9 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
import glob
import json
import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing import Optional, Union
import mlx.core as mx
import mlx.nn as nn
@@ -16,6 +19,7 @@ from mlx.utils import tree_flatten
class CLIPVisionOutput:
pooler_output: mx.array
last_hidden_state: mx.array
hidden_states: Optional[mx.array]
@dataclass
@@ -41,6 +45,7 @@ class CLIPTextConfig:
num_attention_heads: int
max_position_embeddings: int
vocab_size: int
layer_norm_eps: float
@dataclass
@@ -52,6 +57,7 @@ class CLIPVisionConfig:
num_channels: int
image_size: int
patch_size: int
layer_norm_eps: float
@dataclass
@@ -75,50 +81,133 @@ def clip_loss(logits: mx.array) -> mx.array:
return (caption_loss + image_loss) / 2.0
class CLIPEncoderLayer(nn.TransformerEncoderLayer):
class Attention(nn.Module):
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
"The input feature dimensions should be divisible by the "
f"number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
def __call__(self, queries, keys, values, mask=None):
queries = self.q_proj(queries)
keys = self.k_proj(keys)
values = self.v_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
class MLP(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
self.activation_fn = quick_gelu
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def __call__(self, x: mx.array) -> mx.array:
x = self.activation_fn(self.fc1(x))
x = self.fc2(x)
return x
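
quick_gelu is used here but its definition sits outside the hunks shown; for reference, CLIP's QuickGELU is the sigmoid-based GELU approximation, roughly as sketched below (the helper itself is not part of this diff):

```python
import mlx.core as mx

def quick_gelu(x: mx.array) -> mx.array:
    # Sigmoid-based GELU approximation used by the original CLIP weights.
    return x * mx.sigmoid(1.702 * x)
```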
class EncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int):
super().__init__(
dims=hidden_dim,
mlp_dims=intermediate_dim,
num_heads=num_heads,
activation=quick_gelu,
norm_first=True,
)
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.embed_dim = config.hidden_size
# Add biases to the attention projections
self.attention = nn.MultiHeadAttention(hidden_dim, num_heads, bias=True)
self.self_attn = Attention(
config.hidden_size, config.num_attention_heads, bias=True
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = MLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
y = self.layer_norm1(x)
y = self.self_attn(y, y, y, mask)
x = x + y
y = self.layer_norm2(x)
y = self.mlp(y)
return x + y
class CLIPTextModel(nn.Module):
class TextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
def __call__(self, x: mx.array) -> mx.array:
embeddings = self.token_embedding(x)
embeddings += self.position_embedding.weight[: x.shape[1]]
return embeddings
class Encoder(nn.Module):
def __init__(self, config: CLIPTextConfig):
self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
class ClipTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embedding = mx.zeros(
(config.max_position_embeddings, config.hidden_size)
)
self.layers = [
CLIPEncoderLayer(
config.hidden_size, config.intermediate_size, config.num_attention_heads
)
for _ in range(config.num_hidden_layers)
]
self.embeddings = TextEmbeddings(config)
self.encoder = Encoder(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
def _embed(self, x: mx.array) -> mx.array:
embeddings = self.token_embedding(x)
embeddings += self.position_embedding[: x.shape[1]]
return embeddings
def __call__(self, x: mx.array) -> CLIPTextOutput:
B, N = x.shape
eot_tokens = mx.argmax(x, axis=-1)
x = self._embed(x)
x = self.embeddings(x)
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
for l in self.layers:
for l in self.encoder.layers:
x = l(x, mask)
last_hidden_state = self.final_layer_norm(x)
pooler_output = last_hidden_state[mx.arange(B), eot_tokens]
@@ -128,33 +217,29 @@ class CLIPTextModel(nn.Module):
)
class CLIPVisionModel(nn.Module):
"""Implements the vision encoder transformer from CLIP."""
class VisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = mx.zeros((config.hidden_size,))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
num_patches = (config.image_size // config.patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = mx.zeros((num_positions, config.hidden_size))
self.pre_layernorm = nn.LayerNorm(config.hidden_size)
self.layers = [
CLIPEncoderLayer(
config.hidden_size, config.intermediate_size, config.num_attention_heads
)
for _ in range(config.num_hidden_layers)
]
self.post_layernorm = nn.LayerNorm(config.hidden_size)
def _embed(self, x: mx.array) -> mx.array:
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def __call__(self, x: mx.array) -> mx.array:
batch_size = x.shape[0]
# Patchify using conv:
# [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim]
@@ -170,25 +255,48 @@ class CLIPVisionModel(nn.Module):
# [batch_size, num_patches + 1, embed_dim]
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
# Add positional encoding
embeddings += self.position_embedding
embeddings += self.position_embedding.weight
return embeddings
def __call__(self, x: mx.array) -> CLIPVisionOutput:
x = self._embed(x)
x = self.pre_layernorm(x)
for l in self.layers:
class ClipVisionModel(nn.Module):
"""Implements the vision encoder transformer from CLIP."""
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.embeddings = VisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
self.encoder = Encoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size)
def __call__(
self,
x: mx.array,
output_hidden_states: Optional[bool] = None,
) -> CLIPVisionOutput:
x = self.embeddings(x)
x = self.pre_layrnorm(x)
encoder_states = (x,) if output_hidden_states else None
for l in self.encoder.layers:
x = l(x, mask=None)
if output_hidden_states:
encoder_states = encoder_states + (x,)
# Extract <CLS> token embedding
pooler_output = self.post_layernorm(x[:, 0, :])
return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x)
return CLIPVisionOutput(
pooler_output=pooler_output,
last_hidden_state=x,
hidden_states=encoder_states,
)
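
A quick usage sketch for the new output_hidden_states flag; everything below is illustrative (ViT-B/32-style hyperparameters and a dummy NHWC image), and it assumes the CLIPVisionConfig fields and ClipVisionModel definition shown in this diff are in scope:

```python
import mlx.core as mx

# Illustrative ViT-B/32-style hyperparameters; real values come from config.json.
vision_config = CLIPVisionConfig(
    num_hidden_layers=12,
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_channels=3,
    image_size=224,
    patch_size=32,
    layer_norm_eps=1e-5,
)
vision_model = ClipVisionModel(vision_config)

dummy_images = mx.zeros((1, 224, 224, 3))  # NHWC, the layout mlx Conv2d expects
out = vision_model(dummy_images, output_hidden_states=True)
# hidden_states holds the embedding output plus one entry per encoder layer.
print(len(out.hidden_states))  # 13 for 12 layers
print(out.last_hidden_state.shape, out.pooler_output.shape)
```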
class CLIPModel(nn.Module):
def __init__(self, config: CLIPConfig):
self.text_model = CLIPTextModel(config.text_config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.text_model = ClipTextModel(config.text_config)
self.vision_model = ClipVisionModel(config.vision_config)
text_embed_dim = config.text_config.hidden_size
vision_embed_dim = config.vision_config.hidden_size
@@ -259,6 +367,7 @@ class CLIPModel(nn.Module):
num_attention_heads=text_config["num_attention_heads"],
max_position_embeddings=text_config["max_position_embeddings"],
vocab_size=text_config["vocab_size"],
layer_norm_eps=text_config["layer_norm_eps"],
)
vision_config = config["vision_config"]
@@ -271,6 +380,7 @@ class CLIPModel(nn.Module):
num_channels=3,
image_size=vision_config["image_size"],
patch_size=vision_config["patch_size"],
layer_norm_eps=vision_config["layer_norm_eps"],
)
config = CLIPConfig(
@@ -279,5 +389,31 @@ class CLIPModel(nn.Module):
projection_dim=config["projection_dim"],
)
model = CLIPModel(config)
model.load_weights(str(path / "weights.npz"))
weight_files = glob.glob(str(path / "*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {path}")
raise FileNotFoundError(f"No safetensors found in {path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
weights = model.sanitize(weights)
model.load_weights(list(weights.items()))
return model
@staticmethod
def sanitize(weights):
sanitized_weights = {}
for k, v in weights.items():
if "position_ids" in k:
# Remove unused position_ids
continue
elif "patch_embedding.weight" in k:
# pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, kW]
# mlx conv2d expects the weight tensor to be of shape [out_channels, kH, kW, in_channels]
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
else:
sanitized_weights[k] = v
return sanitized_weights
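
To make the patch_embedding branch concrete, here is a small standalone sketch of the layout change sanitize applies; the 768/3/32 sizes are just illustrative ViT-B/32-style values.

```python
import mlx.core as mx

# PyTorch Conv2d stores weights as [out_channels, in_channels, kH, kW] ...
w_torch_layout = mx.zeros((768, 3, 32, 32))
# ... while mlx.nn.Conv2d expects [out_channels, kH, kW, in_channels].
w_mlx_layout = w_torch_layout.transpose(0, 2, 3, 1)
print(tuple(w_mlx_layout.shape))  # (768, 32, 32, 3)
```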

clip/requirements.txt

@@ -3,4 +3,4 @@ numpy
transformers
torch
huggingface_hub
Pillow
Pillow

clip/test.py

@@ -86,12 +86,14 @@ class TestCLIP(unittest.TestCase):
with torch.inference_mode():
# Get expected
x_tc = torch.tensor(x)
expected_out = self.hf_clip.vision_model(x_tc)
expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
expected_last_hidden = expected_out.last_hidden_state.numpy()
expected_pooler_output = expected_out.pooler_output.numpy()
expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]
# Test MLX vision encoder
out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
out = self.mx_clip.vision_model(
mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
)
self.assertTrue(
np.allclose(
out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
@@ -102,6 +104,10 @@ class TestCLIP(unittest.TestCase):
out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
),
)
for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
self.assertTrue(
np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
)
def test_clip_model(self):
image_input = self.hf_image_proc(