diff --git a/clip/convert.py b/clip/convert.py
index 5ce13e10..a646f93f 100644
--- a/clip/convert.py
+++ b/clip/convert.py
@@ -1,15 +1,68 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import json
 import shutil
 from pathlib import Path
-from typing import Tuple
+from typing import Any, Dict, Union
 
 import mlx.core as mx
 import torch
 from huggingface_hub import snapshot_download
 
 
+def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
+    max_file_size_bytes = max_file_size_gb << 30
+    shards = []
+    shard, shard_size = {}, 0
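+    # Greedily pack tensors: once adding the next tensor would push the current
+    # shard past the size limit, close it and start a new one.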
+    for k, v in weights.items():
+        if shard_size + v.nbytes > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += v.nbytes
+    shards.append(shard)
+    return shards
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+    """Save model weights into specified directory."""
+    if isinstance(save_path, str):
+        save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    shards = make_shards(weights)
+    shards_count = len(shards)
+    shard_file_format = (
+        "model-{:05d}-of-{:05d}.safetensors"
+        if shards_count > 1
+        else "model.safetensors"
+    )
+
+    total_size = sum(v.nbytes for v in weights.values())
+    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+
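+    # Write each shard to its own safetensors file and record which shard holds
+    # every weight, following the Hugging Face sharded-checkpoint index format.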
+    for i, shard in enumerate(shards):
+        shard_name = shard_file_format.format(i + 1, shards_count)
+        shard_path = save_path / shard_name
+
+        mx.save_safetensors(str(shard_path), shard)
+
+        for weight_name in shard.keys():
+            index_data["weight_map"][weight_name] = shard_name
+
+    index_data["weight_map"] = {
+        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+    }
+
+    with open(save_path / "model.safetensors.index.json", "w") as f:
+        json.dump(
+            index_data,
+            f,
+            indent=4,
+        )
+
+
 def get_model_path(path_or_hf_repo: str) -> Path:
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
@@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
     return mx.array(a.numpy(), getattr(mx, dtype))
 
 
-def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
-    key = key.replace("embeddings.", "")
-    key = key.replace("encoder.", "")
-    key = key.replace("position_embedding.weight", "position_embedding")
-
-    # Map attention layers
-    if "self_attn." in key:
-        key = key.replace("self_attn.", "attention.")
-    if "q_proj." in key:
-        key = key.replace("q_proj.", "query_proj.")
-    if "k_proj." in key:
-        key = key.replace("k_proj.", "key_proj.")
-    if "v_proj." in key:
-        key = key.replace("v_proj.", "value_proj.")
-    if "layer_norm1." in key:
-        key = key.replace("layer_norm1.", "ln1.")
-    if "layer_norm2." in key:
-        key = key.replace("layer_norm2.", "ln2.")
-    # Map ffn layers
-    if "mlp.fc1" in key:
-        key = key.replace("mlp.fc1", "linear1")
-    if "mlp.fc2" in key:
-        key = key.replace("mlp.fc2", "linear2")
-    # Fix layernorm typo
-    if "pre_layrnorm" in key:
-        # Fix typo in weights :)
-        key = key.replace("pre_layrnorm", "pre_layernorm")
-    if "patch_embedding.weight" in key:
-        # Initially, value: [out_channels, in_channels, kH, KW].
-        # We want [out_channels, kH, KW, in_channels]
-        value = value.permute(0, 2, 3, 1)
-    return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))
-
-
-def should_keep_weight(key: str):
-    return not ("position_ids" in key)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Download and Convert (OpenAI) CLIP weights to MLX"
@@ -86,7 +101,12 @@ if __name__ == "__main__":
         default="mlx_model",
         help="Path to save the MLX model.",
     )
-
+    parser.add_argument(
+        "--dtype",
+        help="The data type to save the converted model.",
+        type=str,
+        default="float32",
+    )
     args = parser.parse_args()
 
     torch_path = get_model_path(args.hf_repo)
@@ -96,10 +116,11 @@ if __name__ == "__main__":
     print("[INFO] Loading")
     torch_weights = torch.load(torch_path / "pytorch_model.bin")
     print("[INFO] Converting")
-    mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
-    mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
+    mlx_weights = {
+        k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
+    }
     print("[INFO] Saving")
-    mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
+    save_weights(mlx_path, mlx_weights)
     for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
         shutil.copyfile(
             str(torch_path / f"{fn}"),
diff --git a/clip/model.py b/clip/model.py
index 407a5e6c..384fd59a 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -1,9 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import glob
 import json
+import logging
+import math
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -16,6 +19,7 @@ from mlx.utils import tree_flatten
 class CLIPVisionOutput:
     pooler_output: mx.array
     last_hidden_state: mx.array
+    hidden_states: Optional[mx.array]
 
 
 @dataclass
@@ -41,6 +45,7 @@ class CLIPTextConfig:
     num_attention_heads: int
     max_position_embeddings: int
     vocab_size: int
+    layer_norm_eps: float
 
 
 @dataclass
@@ -52,6 +57,7 @@ class CLIPVisionConfig:
     num_channels: int
     image_size: int
     patch_size: int
+    layer_norm_eps: float
 
 
 @dataclass
@@ -75,50 +81,133 @@ def clip_loss(logits: mx.array) -> mx.array:
     return (caption_loss + image_loss) / 2.0
 
 
-class CLIPEncoderLayer(nn.TransformerEncoderLayer):
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None):
+        queries = self.q_proj(queries)
+        keys = self.k_proj(keys)
+        values = self.v_proj(values)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
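+        # Split into heads: queries and values become [B, num_heads, L/S, head_dim];
+        # keys become [B, num_heads, head_dim, S] so that queries @ keys gives scores.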
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys
+        if mask is not None:
+            scores = scores + mask.astype(scores.dtype)
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.config = config
+        self.activation_fn = quick_gelu
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
     """The transformer encoder layer from CLIP."""
 
-    def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int):
-        super().__init__(
-            dims=hidden_dim,
-            mlp_dims=intermediate_dim,
-            num_heads=num_heads,
-            activation=quick_gelu,
-            norm_first=True,
-        )
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
         # Add biases to the attention projections
-        self.attention = nn.MultiHeadAttention(hidden_dim, num_heads, bias=True)
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        y = self.layer_norm1(x)
+        y = self.self_attn(y, y, y, mask)
+        x = x + y
+        y = self.layer_norm2(x)
+        y = self.mlp(y)
+        return x + y
 
 
-class CLIPTextModel(nn.Module):
+class TextEmbeddings(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(
+            config.max_position_embeddings, embed_dim
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        embeddings = self.token_embedding(x)
+        embeddings += self.position_embedding.weight[: x.shape[1]]
+        return embeddings
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+
+class ClipTextModel(nn.Module):
     """Implements the text encoder transformer from CLIP."""
 
     def __init__(self, config: CLIPTextConfig):
         super().__init__()
-
-        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.position_embedding = mx.zeros(
-            (config.max_position_embeddings, config.hidden_size)
-        )
-        self.layers = [
-            CLIPEncoderLayer(
-                config.hidden_size, config.intermediate_size, config.num_attention_heads
-            )
-            for _ in range(config.num_hidden_layers)
-        ]
+        self.embeddings = TextEmbeddings(config)
+        self.encoder = Encoder(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size)
 
-    def _embed(self, x: mx.array) -> mx.array:
-        embeddings = self.token_embedding(x)
-        embeddings += self.position_embedding[: x.shape[1]]
-        return embeddings
-
     def __call__(self, x: mx.array) -> CLIPTextOutput:
         B, N = x.shape
         eot_tokens = mx.argmax(x, axis=-1)
-        x = self._embed(x)
+        x = self.embeddings(x)
         mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
-        for l in self.layers:
+        for l in self.encoder.layers:
             x = l(x, mask)
         last_hidden_state = self.final_layer_norm(x)
         pooler_output = last_hidden_state[mx.arange(B), eot_tokens]
@@ -128,33 +217,29 @@ class CLIPTextModel(nn.Module):
         )
 
 
-class CLIPVisionModel(nn.Module):
-    """Implements the vision encoder transformer from CLIP."""
-
+class VisionEmbeddings(nn.Module):
     def __init__(self, config: CLIPVisionConfig):
         super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
 
         self.class_embedding = mx.zeros((config.hidden_size,))
+
         self.patch_embedding = nn.Conv2d(
             in_channels=config.num_channels,
-            out_channels=config.hidden_size,
-            kernel_size=config.patch_size,
-            stride=config.patch_size,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
             bias=False,
         )
-        num_patches = (config.image_size // config.patch_size) ** 2
-        num_positions = num_patches + 1
-        self.position_embedding = mx.zeros((num_positions, config.hidden_size))
-        self.pre_layernorm = nn.LayerNorm(config.hidden_size)
-        self.layers = [
-            CLIPEncoderLayer(
-                config.hidden_size, config.intermediate_size, config.num_attention_heads
-            )
-            for _ in range(config.num_hidden_layers)
-        ]
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)
 
-    def _embed(self, x: mx.array) -> mx.array:
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, x: mx.array) -> mx.array:
         batch_size = x.shape[0]
         # Patchify using conv:
         # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim]
@@ -170,25 +255,48 @@ class CLIPVisionModel(nn.Module):
         # [batch_size, num_patches + 1, embed_dim]
         embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
         # Add positional encoding
-        embeddings += self.position_embedding
+        embeddings += self.position_embedding.weight
         return embeddings
 
-    def __call__(self, x: mx.array) -> CLIPVisionOutput:
-        x = self._embed(x)
-        x = self.pre_layernorm(x)
-        for l in self.layers:
+
+class ClipVisionModel(nn.Module):
+    """Implements the vision encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.embeddings = VisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> CLIPVisionOutput:
+        x = self.embeddings(x)
+        x = self.pre_layrnorm(x)
+
+        encoder_states = (x,) if output_hidden_states else None
+
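+        # When requested, collect the embedding output plus every layer's output,
+        # matching the hidden_states tuple returned by the Hugging Face vision model.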
+        for l in self.encoder.layers:
             x = l(x, mask=None)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
 
         # Extract token embedding
         pooler_output = self.post_layernorm(x[:, 0, :])
-        return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x)
+        return CLIPVisionOutput(
+            pooler_output=pooler_output,
+            last_hidden_state=x,
+            hidden_states=encoder_states,
+        )
 
 
 class CLIPModel(nn.Module):
     def __init__(self, config: CLIPConfig):
-        self.text_model = CLIPTextModel(config.text_config)
-        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.text_model = ClipTextModel(config.text_config)
+        self.vision_model = ClipVisionModel(config.vision_config)
 
         text_embed_dim = config.text_config.hidden_size
         vision_embed_dim = config.vision_config.hidden_size
@@ -259,6 +367,7 @@ class CLIPModel(nn.Module):
             num_attention_heads=text_config["num_attention_heads"],
             max_position_embeddings=text_config["max_position_embeddings"],
             vocab_size=text_config["vocab_size"],
+            layer_norm_eps=text_config["layer_norm_eps"],
         )
 
         vision_config = config["vision_config"]
@@ -271,6 +380,7 @@ class CLIPModel(nn.Module):
             num_channels=3,
             image_size=vision_config["image_size"],
             patch_size=vision_config["patch_size"],
+            layer_norm_eps=vision_config["layer_norm_eps"],
         )
 
         config = CLIPConfig(
@@ -279,5 +389,31 @@ class CLIPModel(nn.Module):
             projection_dim=config["projection_dim"],
         )
         model = CLIPModel(config)
-        model.load_weights(str(path / "weights.npz"))
+        weight_files = glob.glob(str(path / "*.safetensors"))
+        if not weight_files:
+            logging.error(f"No safetensors found in {path}")
+            raise FileNotFoundError(f"No safetensors found in {path}")
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf))
+
+        weights = model.sanitize(weights)
+        model.load_weights(list(weights.items()))
         return model
+
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW]
+                # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels]
+                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
diff --git a/clip/requirements.txt b/clip/requirements.txt
index a6f2f1ca..74f826ea 100644
--- a/clip/requirements.txt
+++ b/clip/requirements.txt
@@ -3,4 +3,4 @@ numpy
 transformers
 torch
 huggingface_hub
-Pillow
\ No newline at end of file
+Pillow
diff --git a/clip/test.py b/clip/test.py
index d8a2f84d..63086953 100644
--- a/clip/test.py
+++ b/clip/test.py
@@ -86,12 +86,14 @@ class TestCLIP(unittest.TestCase):
         with torch.inference_mode():
             # Get expected
             x_tc = torch.tensor(x)
-            expected_out = self.hf_clip.vision_model(x_tc)
+            expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
             expected_last_hidden = expected_out.last_hidden_state.numpy()
             expected_pooler_output = expected_out.pooler_output.numpy()
-
+            expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]
         # Test MLX vision encoder
-        out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
+        out = self.mx_clip.vision_model(
+            mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
+        )
         self.assertTrue(
             np.allclose(
                 out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
@@ -102,6 +104,10 @@ class TestCLIP(unittest.TestCase):
                 out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
             ),
         )
+        for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
+            self.assertTrue(
+                np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
+            )
 
     def test_clip_model(self):
         image_input = self.hf_image_proc(