chore(clip): update the clip example to make it compatible with HF format (#472)

* chore(clip): update the clip model to the HF format

* Update clip/convert.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: address comments

* chore: rename ClipVisionModel and ClipTextModel

* chore: add output hidden_states support

* chore: remove custom conv2d and apply weight transpose during weight sanitizing

* Update clip/model.py

* Update clip/model.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Anchen 2024-02-24 01:49:53 +11:00 committed by GitHub
parent f24edfa9dc
commit 47dd6bd17f
4 changed files with 267 additions and 104 deletions

clip/convert.py

@@ -1,15 +1,68 @@
 # Copyright © 2023-2024 Apple Inc.

 import argparse
+import json
 import shutil
 from pathlib import Path
-from typing import Tuple
+from typing import Any, Dict, Union

 import mlx.core as mx
 import torch
 from huggingface_hub import snapshot_download


+def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
+    max_file_size_bytes = max_file_size_gb << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        if shard_size + v.nbytes > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += v.nbytes
+    shards.append(shard)
+    return shards
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+    """Save model weights into specified directory."""
+    if isinstance(save_path, str):
+        save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    shards = make_shards(weights)
+    shards_count = len(shards)
+    shard_file_format = (
+        "model-{:05d}-of-{:05d}.safetensors"
+        if shards_count > 1
+        else "model.safetensors"
+    )
+
+    total_size = sum(v.nbytes for v in weights.values())
+    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+
+    for i, shard in enumerate(shards):
+        shard_name = shard_file_format.format(i + 1, shards_count)
+        shard_path = save_path / shard_name
+
+        mx.save_safetensors(str(shard_path), shard)
+
+        for weight_name in shard.keys():
+            index_data["weight_map"][weight_name] = shard_name
+
+    index_data["weight_map"] = {
+        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+    }
+
+    with open(save_path / "model.safetensors.index.json", "w") as f:
+        json.dump(
+            index_data,
+            f,
+            indent=4,
+        )
+
+
 def get_model_path(path_or_hf_repo: str) -> Path:
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
@@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
     return mx.array(a.numpy(), getattr(mx, dtype))


-def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
-    key = key.replace("embeddings.", "")
-    key = key.replace("encoder.", "")
-    key = key.replace("position_embedding.weight", "position_embedding")
-
-    # Map attention layers
-    if "self_attn." in key:
-        key = key.replace("self_attn.", "attention.")
-    if "q_proj." in key:
-        key = key.replace("q_proj.", "query_proj.")
-    if "k_proj." in key:
-        key = key.replace("k_proj.", "key_proj.")
-    if "v_proj." in key:
-        key = key.replace("v_proj.", "value_proj.")
-    if "layer_norm1." in key:
-        key = key.replace("layer_norm1.", "ln1.")
-    if "layer_norm2." in key:
-        key = key.replace("layer_norm2.", "ln2.")
-    # Map ffn layers
-    if "mlp.fc1" in key:
-        key = key.replace("mlp.fc1", "linear1")
-    if "mlp.fc2" in key:
-        key = key.replace("mlp.fc2", "linear2")
-    # Fix layernorm typo
-    if "pre_layrnorm" in key:
-        # Fix typo in weights :)
-        key = key.replace("pre_layrnorm", "pre_layernorm")
-    if "patch_embedding.weight" in key:
-        # Initially, value: [out_channels, in_channels, kH, KW].
-        # We want [out_channels, kH, KW, in_channels]
-        value = value.permute(0, 2, 3, 1)
-    return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))
-
-
-def should_keep_weight(key: str):
-    return not ("position_ids" in key)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Download and Convert (OpenAI) CLIP weights to MLX"
@@ -86,7 +101,12 @@ if __name__ == "__main__":
         default="mlx_model",
         help="Path to save the MLX model.",
     )
+    parser.add_argument(
+        "--dtype",
+        help="The data type to save the converted model.",
+        type=str,
+        default="float32",
+    )
     args = parser.parse_args()

     torch_path = get_model_path(args.hf_repo)
@@ -96,10 +116,11 @@ if __name__ == "__main__":
     print("[INFO] Loading")
     torch_weights = torch.load(torch_path / "pytorch_model.bin")
     print("[INFO] Converting")
-    mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
-    mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
+    mlx_weights = {
+        k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
+    }
     print("[INFO] Saving")
-    mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
+    save_weights(mlx_path, mlx_weights)
     for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
        shutil.copyfile(
            str(torch_path / f"{fn}"),
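Note: with these changes the converter writes sharded safetensors plus an index file instead of a single weights.npz. A minimal sketch of the new save path, assuming it is run from the clip/ directory so convert.py is importable (the tensor names and sizes below are made up for illustration):

    import mlx.core as mx
    from convert import make_shards, save_weights

    # Tiny fake weight dict; real CLIP weights come from torch.load + torch_to_mx as above.
    weights = {"layer.weight": mx.zeros((4, 4)), "layer.bias": mx.zeros((4,))}
    print(len(make_shards(weights)))  # 1 shard, since these tensors are far below the 5 GB limit
    save_weights("mlx_model_test", weights)
    # Writes mlx_model_test/model.safetensors and mlx_model_test/model.safetensors.index.json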

clip/model.py

@@ -1,9 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.

+import glob
 import json
+import logging
+import math
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -16,6 +19,7 @@ from mlx.utils import tree_flatten
 class CLIPVisionOutput:
     pooler_output: mx.array
     last_hidden_state: mx.array
+    hidden_states: Optional[mx.array]


 @dataclass
@@ -41,6 +45,7 @@ class CLIPTextConfig:
     num_attention_heads: int
     max_position_embeddings: int
     vocab_size: int
+    layer_norm_eps: float


 @dataclass
@@ -52,6 +57,7 @@ class CLIPVisionConfig:
     num_channels: int
     image_size: int
     patch_size: int
+    layer_norm_eps: float


 @dataclass
@@ -75,50 +81,133 @@ def clip_loss(logits: mx.array) -> mx.array:
     return (caption_loss + image_loss) / 2.0


-class CLIPEncoderLayer(nn.TransformerEncoderLayer):
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None):
+        queries = self.q_proj(queries)
+        keys = self.k_proj(keys)
+        values = self.v_proj(values)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys
+        if mask is not None:
+            scores = scores + mask.astype(scores.dtype)
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.config = config
+        self.activation_fn = quick_gelu
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
     """The transformer encoder layer from CLIP."""

-    def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int):
-        super().__init__(
-            dims=hidden_dim,
-            mlp_dims=intermediate_dim,
-            num_heads=num_heads,
-            activation=quick_gelu,
-            norm_first=True,
-        )
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
         # Add biases to the attention projections
-        self.attention = nn.MultiHeadAttention(hidden_dim, num_heads, bias=True)
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        y = self.layer_norm1(x)
+        y = self.self_attn(y, y, y, mask)
+        x = x + y
+        y = self.layer_norm2(x)
+        y = self.mlp(y)
+        return x + y
+
+
+class TextEmbeddings(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(
+            config.max_position_embeddings, embed_dim
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        embeddings = self.token_embedding(x)
+        embeddings += self.position_embedding.weight[: x.shape[1]]
+        return embeddings
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]


-class CLIPTextModel(nn.Module):
+class ClipTextModel(nn.Module):
     """Implements the text encoder transformer from CLIP."""

     def __init__(self, config: CLIPTextConfig):
         super().__init__()
-
-        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.position_embedding = mx.zeros(
-            (config.max_position_embeddings, config.hidden_size)
-        )
-        self.layers = [
-            CLIPEncoderLayer(
-                config.hidden_size, config.intermediate_size, config.num_attention_heads
-            )
-            for _ in range(config.num_hidden_layers)
-        ]
+        self.embeddings = TextEmbeddings(config)
+        self.encoder = Encoder(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size)

-    def _embed(self, x: mx.array) -> mx.array:
-        embeddings = self.token_embedding(x)
-        embeddings += self.position_embedding[: x.shape[1]]
-        return embeddings
-
     def __call__(self, x: mx.array) -> CLIPTextOutput:
         B, N = x.shape
         eot_tokens = mx.argmax(x, axis=-1)
-        x = self._embed(x)
+        x = self.embeddings(x)
         mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
-        for l in self.layers:
+        for l in self.encoder.layers:
             x = l(x, mask)
         last_hidden_state = self.final_layer_norm(x)
         pooler_output = last_hidden_state[mx.arange(B), eot_tokens]
@@ -128,33 +217,29 @@ class CLIPTextModel(nn.Module):
         )


-class CLIPVisionModel(nn.Module):
-    """Implements the vision encoder transformer from CLIP."""
-
+class VisionEmbeddings(nn.Module):
     def __init__(self, config: CLIPVisionConfig):
         super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
         self.class_embedding = mx.zeros((config.hidden_size,))
+
         self.patch_embedding = nn.Conv2d(
             in_channels=config.num_channels,
-            out_channels=config.hidden_size,
-            kernel_size=config.patch_size,
-            stride=config.patch_size,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
             bias=False,
         )
-        num_patches = (config.image_size // config.patch_size) ** 2
-        num_positions = num_patches + 1
-        self.position_embedding = mx.zeros((num_positions, config.hidden_size))
-        self.pre_layernorm = nn.LayerNorm(config.hidden_size)
-        self.layers = [
-            CLIPEncoderLayer(
-                config.hidden_size, config.intermediate_size, config.num_attention_heads
-            )
-            for _ in range(config.num_hidden_layers)
-        ]
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)

-    def _embed(self, x: mx.array) -> mx.array:
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, x: mx.array) -> mx.array:
         batch_size = x.shape[0]
         # Patchify using conv:
         # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim]
@@ -170,25 +255,48 @@ class CLIPVisionModel(nn.Module):
         # [batch_size, num_patches + 1, embed_dim]
         embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
         # Add positional encoding
-        embeddings += self.position_embedding
+        embeddings += self.position_embedding.weight
         return embeddings

-    def __call__(self, x: mx.array) -> CLIPVisionOutput:
-        x = self._embed(x)
-        x = self.pre_layernorm(x)

-        for l in self.layers:
+class ClipVisionModel(nn.Module):
+    """Implements the vision encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.embeddings = VisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> CLIPVisionOutput:
+        x = self.embeddings(x)
+        x = self.pre_layrnorm(x)
+
+        encoder_states = (x,) if output_hidden_states else None
+
+        for l in self.encoder.layers:
             x = l(x, mask=None)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)

         # Extract <CLS> token embedding
         pooler_output = self.post_layernorm(x[:, 0, :])
-        return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x)
+
+        return CLIPVisionOutput(
+            pooler_output=pooler_output,
+            last_hidden_state=x,
+            hidden_states=encoder_states,
+        )


 class CLIPModel(nn.Module):
     def __init__(self, config: CLIPConfig):
-        self.text_model = CLIPTextModel(config.text_config)
-        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.text_model = ClipTextModel(config.text_config)
+        self.vision_model = ClipVisionModel(config.vision_config)

         text_embed_dim = config.text_config.hidden_size
         vision_embed_dim = config.vision_config.hidden_size
@@ -259,6 +367,7 @@ class CLIPModel(nn.Module):
             num_attention_heads=text_config["num_attention_heads"],
             max_position_embeddings=text_config["max_position_embeddings"],
             vocab_size=text_config["vocab_size"],
+            layer_norm_eps=text_config["layer_norm_eps"],
         )

         vision_config = config["vision_config"]
@@ -271,6 +380,7 @@ class CLIPModel(nn.Module):
             num_channels=3,
             image_size=vision_config["image_size"],
             patch_size=vision_config["patch_size"],
+            layer_norm_eps=vision_config["layer_norm_eps"],
         )

         config = CLIPConfig(
@@ -279,5 +389,31 @@ class CLIPModel(nn.Module):
             projection_dim=config["projection_dim"],
         )
         model = CLIPModel(config)
-        model.load_weights(str(path / "weights.npz"))
+
+        weight_files = glob.glob(str(path / "*.safetensors"))
+        if not weight_files:
+            logging.error(f"No safetensors found in {path}")
+            raise FileNotFoundError(f"No safetensors found in {path}")
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf))
+
+        weights = model.sanitize(weights)
+        model.load_weights(list(weights.items()))
         return model
+
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW]
+                # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels]
+                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
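Note: the sanitize step above replaces the convert-time permute that was removed from convert.py. PyTorch stores Conv2d kernels as [out_channels, in_channels, kH, kW], while MLX's Conv2d expects [out_channels, kH, kW, in_channels]. A minimal sketch of the transpose, using an illustrative ViT-B/32-sized patch-embedding kernel (the shape is assumed for the example, not taken from the diff):

    import mlx.core as mx

    # Fake patch_embedding.weight in PyTorch layout: 768 filters over 3-channel 32x32 patches.
    w_pt = mx.zeros((768, 3, 32, 32))
    w_mlx = w_pt.transpose(0, 2, 3, 1)  # same axes as CLIPModel.sanitize
    print(w_mlx.shape)  # -> 768 x 32 x 32 x 3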

clip/requirements.txt

@@ -3,4 +3,4 @@ numpy
 transformers
 torch
 huggingface_hub
-Pillow
+Pillow

clip/test.py

@@ -86,12 +86,14 @@ class TestCLIP(unittest.TestCase):
         with torch.inference_mode():
             # Get expected
             x_tc = torch.tensor(x)
-            expected_out = self.hf_clip.vision_model(x_tc)
+            expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
             expected_last_hidden = expected_out.last_hidden_state.numpy()
             expected_pooler_output = expected_out.pooler_output.numpy()
+            expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]

         # Test MLX vision encoder
-        out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
+        out = self.mx_clip.vision_model(
+            mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
+        )
         self.assertTrue(
             np.allclose(
                 out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
@@ -102,6 +104,10 @@ class TestCLIP(unittest.TestCase):
                 out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
             ),
         )
+        for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
+            self.assertTrue(
+                np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
+            )

     def test_clip_model(self):
         image_input = self.hf_image_proc(
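Note: the output_hidden_states flag exercised by this test mirrors the Hugging Face API. A minimal sketch of calling the MLX vision tower directly, assuming it runs from the clip/ directory and that CLIPVisionConfig takes the fields shown in the config hunks above (the tiny sizes are made up and the weights are untrained, so only shapes are meaningful):

    import mlx.core as mx
    from model import CLIPVisionConfig, ClipVisionModel

    config = CLIPVisionConfig(
        num_hidden_layers=2,
        hidden_size=64,
        intermediate_size=128,
        num_attention_heads=4,
        num_channels=3,
        image_size=32,
        patch_size=8,
        layer_norm_eps=1e-5,
    )
    vision_model = ClipVisionModel(config)
    x = mx.zeros((1, 32, 32, 3))  # NHWC, as the MLX model expects
    out = vision_model(x, output_hidden_states=True)
    print(out.last_hidden_state.shape)  # batch x (16 patches + CLS) x hidden_size
    print(len(out.hidden_states))  # embeddings output + one entry per encoder layer = 3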