diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 65903bce..3f445c0b 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -10,3 +10,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Juarez Bochi: Added support for T5 models.
 - Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
 - Shunta Saito: Added support for PLaMo models.
+- Gabrijel Boduljak: Implemented `CLIP`.
diff --git a/README.md b/README.md
index 8d7f7030..653cbf01 100644
--- a/README.md
+++ b/README.md
@@ -26,9 +26,15 @@ Some more useful examples are listed below.
 
 - Speech recognition with [OpenAI's Whisper](whisper).
 
+### Multimodal models
+
+- Joint text and image embeddings with [CLIP](clip).
+
 ### Other Models
 
 - Semi-supervised learning on graph-structured data with [GCN](gcn).
+- Real NVP [normalizing flow](normalizing_flow) for density estimation and
+  sampling.
 
 ### Hugging Face
diff --git a/clip/.gitignore b/clip/.gitignore
new file mode 100644
index 00000000..ee603eb4
--- /dev/null
+++ b/clip/.gitignore
@@ -0,0 +1 @@
+mlx_model/
diff --git a/clip/README.md b/clip/README.md
new file mode 100644
index 00000000..c7c18720
--- /dev/null
+++ b/clip/README.md
@@ -0,0 +1,76 @@
+# CLIP
+
+An example of OpenAI's CLIP in MLX. The CLIP (contrastive language-image
+pre-training) model embeds images and text in the same space.[^1]
+
+### Setup
+
+Install the dependencies:
+
+```shell
+pip install -r requirements.txt
+```
+
+Next, download a CLIP model from Hugging Face and convert it to MLX. The
+default model is
+[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).
+
+```shell
+python convert.py
+```
+
+The script will by default download the model and configuration files to the
+directory ``mlx_model/``.
+
+### Run
+
+You can use the CLIP model to embed images and text.
+
+```python
+from PIL import Image
+import clip
+
+model, tokenizer, img_processor = clip.load("mlx_model")
+inputs = {
+    "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]),
+    "pixel_values": img_processor(
+        [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
+    ),
+}
+output = model(**inputs)
+
+# Get text and image embeddings:
+text_embeds = output.text_embeds
+image_embeds = output.image_embeds
+```
+
+Run the above example with `python clip.py`.
+
+To embed only images or only text, pass only the ``pixel_values`` or the
+``input_ids``, respectively.
+
+This example re-implements minimal image preprocessing and tokenization to
+reduce dependencies. For additional preprocessing functionality, you can use
+``transformers``. The file `hf_preproc.py` has an example.
+
+MLX CLIP has been tested and works with the following Hugging Face repos:
+
+- [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)
+- [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+
+You can run the tests with:
+
+```shell
+python test.py
+```
+
+To test new models, update the `MLX_PATH` and `HF_PATH` in `test.py`.
+
+### Attribution
+
+- `assets/cat.jpeg` is a "Cat" by London's, licensed under CC BY-SA 2.0.
+- `assets/dog.jpeg` is a "Happy Dog" by tedmurphy, licensed under CC BY 2.0.
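+
+### Similarity scores
+
+Because the text and image embeddings above are L2-normalized and live in the
+same space, you can score images against text prompts with a dot product. The
+sketch below builds on the snippet from the Run section; it is not a separate
+API of this example, just plain MLX on top of ``text_embeds`` and
+``image_embeds``.
+
+```python
+import mlx.core as mx
+
+# Cosine similarity of every prompt with every image
+# (rows: prompts, columns: images); the embeddings are unit norm.
+similarity = text_embeds @ image_embeds.T
+
+# Scale by the model's learned temperature and take a softmax over images to
+# get zero-shot matching probabilities for each prompt.
+probs = mx.softmax(mx.exp(model.logit_scale) * similarity, axis=-1)
+print(probs)
+```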
+ +[^1]: Refer to the original paper [Learning Transferable Visual Models From + Natural Language Supervision ](https://arxiv.org/abs/2103.00020) or [blog + post](https://openai.com/research/clip) diff --git a/clip/assets/cat.jpeg b/clip/assets/cat.jpeg new file mode 100644 index 00000000..f66e88e0 Binary files /dev/null and b/clip/assets/cat.jpeg differ diff --git a/clip/assets/dog.jpeg b/clip/assets/dog.jpeg new file mode 100644 index 00000000..78ca0fbe Binary files /dev/null and b/clip/assets/dog.jpeg differ diff --git a/clip/clip.py b/clip/clip.py new file mode 100644 index 00000000..1266cae2 --- /dev/null +++ b/clip/clip.py @@ -0,0 +1,31 @@ +from typing import Tuple + +from image_processor import CLIPImageProcessor +from model import CLIPModel +from tokenizer import CLIPTokenizer + + +def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]: + model = CLIPModel.from_pretrained(model_dir) + tokenizer = CLIPTokenizer.from_pretrained(model_dir) + img_processor = CLIPImageProcessor.from_pretrained(model_dir) + return model, tokenizer, img_processor + + +if __name__ == "__main__": + from PIL import Image + + model, tokenizer, img_processor = load("mlx_model") + inputs = { + "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]), + "pixel_values": img_processor( + [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")] + ), + } + output = model(**inputs) + + # Get text and image embeddings: + text_embeds = output.text_embeds + image_embeds = output.image_embeds + print("Text embeddings shape:", text_embeds.shape) + print("Image embeddings shape:", image_embeds.shape) diff --git a/clip/convert.py b/clip/convert.py new file mode 100644 index 00000000..5ce13e10 --- /dev/null +++ b/clip/convert.py @@ -0,0 +1,107 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +import shutil +from pathlib import Path +from typing import Tuple + +import mlx.core as mx +import torch +from huggingface_hub import snapshot_download + + +def get_model_path(path_or_hf_repo: str) -> Path: + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.bin", + "*.json", + "*.txt", + ], + ) + ) + return model_path + + +def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array: + # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss + a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype)) + return mx.array(a.numpy(), getattr(mx, dtype)) + + +def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]: + key = key.replace("embeddings.", "") + key = key.replace("encoder.", "") + key = key.replace("position_embedding.weight", "position_embedding") + + # Map attention layers + if "self_attn." in key: + key = key.replace("self_attn.", "attention.") + if "q_proj." in key: + key = key.replace("q_proj.", "query_proj.") + if "k_proj." in key: + key = key.replace("k_proj.", "key_proj.") + if "v_proj." in key: + key = key.replace("v_proj.", "value_proj.") + if "layer_norm1." in key: + key = key.replace("layer_norm1.", "ln1.") + if "layer_norm2." 
in key: + key = key.replace("layer_norm2.", "ln2.") + # Map ffn layers + if "mlp.fc1" in key: + key = key.replace("mlp.fc1", "linear1") + if "mlp.fc2" in key: + key = key.replace("mlp.fc2", "linear2") + # Fix layernorm typo + if "pre_layrnorm" in key: + # Fix typo in weights :) + key = key.replace("pre_layrnorm", "pre_layernorm") + if "patch_embedding.weight" in key: + # Initially, value: [out_channels, in_channels, kH, KW]. + # We want [out_channels, kH, KW, in_channels] + value = value.permute(0, 2, 3, 1) + return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", ""))) + + +def should_keep_weight(key: str): + return not ("position_ids" in key) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Download and Convert (OpenAI) CLIP weights to MLX" + ) + parser.add_argument( + "--hf-repo", + type=str, + default="openai/clip-vit-base-patch32", + help="Hugging Face repository name.", + ) + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="Path to save the MLX model.", + ) + + args = parser.parse_args() + + torch_path = get_model_path(args.hf_repo) + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + + print("[INFO] Loading") + torch_weights = torch.load(torch_path / "pytorch_model.bin") + print("[INFO] Converting") + mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items()) + mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)} + print("[INFO] Saving") + mx.savez(str(mlx_path / "weights.npz"), **mlx_weights) + for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]: + shutil.copyfile( + str(torch_path / f"{fn}"), + str(mlx_path / f"{fn}"), + ) diff --git a/clip/hf_preproc.py b/clip/hf_preproc.py new file mode 100644 index 00000000..ac0490d2 --- /dev/null +++ b/clip/hf_preproc.py @@ -0,0 +1,29 @@ +import mlx.core as mx +import transformers +from PIL import Image + +import clip + +hf_model = "openai/clip-vit-base-patch32" +mlx_model = "mlx_model" + +model, *_ = clip.load(mlx_model) +processor = transformers.CLIPProcessor.from_pretrained(hf_model) + +inputs = processor( + text=["a photo of a cat", "a photo of a dog"], + images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")], + return_tensors="np", +) + +out = model( + input_ids=mx.array(inputs.input_ids), + pixel_values=mx.array(inputs.pixel_values).transpose((0, 2, 3, 1)), + return_loss=True, +) + +print("text embeddings:") +print(out.text_embeds) +print("image embeddings:") +print(out.image_embeds) +print(f"CLIP loss: {out.loss.item():.3f}") diff --git a/clip/image_processor.py b/clip/image_processor.py new file mode 100644 index 00000000..832f5f6f --- /dev/null +++ b/clip/image_processor.py @@ -0,0 +1,93 @@ +# Copyright © 2023-2024 Apple Inc. + +import json +from pathlib import Path +from typing import List, Tuple + +import mlx.core as mx +import numpy as np +from PIL.Image import Image + + +class CLIPImageProcessor: + """ + A simple port of + https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 
+ """ + + def __init__( + self, + crop_size: int = 224, + do_center_crop: bool = True, + do_normalize: bool = True, + do_resize: bool = True, + image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], + image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], + size: int = 224, + **kwargs + ) -> None: + self.crop_size = crop_size + self.do_center_crop = do_center_crop + self.do_normalize = do_normalize + self.do_resize = do_resize + self.image_mean = mx.array(image_mean) + self.image_std = mx.array(image_std) + self.size = size + + def __call__(self, images: List[Image]) -> mx.array: + return mx.concatenate( + [self._preprocess(image)[None] for image in images], axis=0 + ) + + def _preprocess(self, image: Image) -> mx.array: + if self.do_resize: + image = resize(image, self.size) + if self.do_center_crop: + image = center_crop(image, (self.crop_size, self.crop_size)) + image = mx.array(np.array(image)) + image = rescale(image) + if self.do_normalize: + image = normalize(image, self.image_mean, self.image_std) + return image + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + with open(path / "preprocessor_config.json", encoding="utf-8") as f: + config = json.load(f) + return CLIPImageProcessor(**config) + + +def resize(image: Image, short_size: int) -> Image: + """ + Resize so small size to short_size + """ + width, height = image.size + short = min(width, height) + long = max(width, height) + if short == short_size: + return image + new_short = short_size + new_long = int(short_size * long / short) + new_size = (new_short, new_long) if width <= height else (new_long, new_short) + return image.resize(new_size) + + +def center_crop(image: Image, size: Tuple[int, int]) -> Image: + if size[0] % 2 != 0 or size[1] % 2 != 0: + raise ValueError("Only even crop sizes supported.") + original_width, original_height = image.size + crop_height, crop_width = size + top = (original_height - crop_height) // 2 + bottom = top + crop_height + left = (original_width - crop_width) // 2 + right = left + crop_width + return image.crop((left, top, right, bottom)) + + +def rescale(image: mx.array) -> mx.array: + return image.astype(mx.float32) * (1 / 255.0) + + +def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: + return (image - mean) / std diff --git a/clip/model.py b/clip/model.py new file mode 100644 index 00000000..407a5e6c --- /dev/null +++ b/clip/model.py @@ -0,0 +1,283 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.core import linalg as LA +from mlx.nn.losses import cross_entropy +from mlx.utils import tree_flatten + + +@dataclass +class CLIPVisionOutput: + pooler_output: mx.array + last_hidden_state: mx.array + + +@dataclass +class CLIPTextOutput: + pooler_output: mx.array + last_hidden_state: mx.array + + +@dataclass +class CLIPModelOutput: + loss: Optional[mx.array] + text_embeds: Optional[mx.array] + image_embeds: Optional[mx.array] + text_model_output: CLIPTextOutput + vision_model_output: CLIPVisionOutput + + +@dataclass +class CLIPTextConfig: + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int + max_position_embeddings: int + vocab_size: int + + +@dataclass +class CLIPVisionConfig: + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_channels: int + image_size: int + patch_size: int + + +@dataclass +class CLIPConfig: + text_config: CLIPTextConfig + vision_config: CLIPVisionConfig + projection_dim: int + + +def quick_gelu(x: mx.array) -> mx.array: + """ + A fast GELU approximation https://github.com/hendrycks/GELUs + """ + return x * mx.sigmoid(1.702 * x) + + +def clip_loss(logits: mx.array) -> mx.array: + N, M = logits.shape + caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean") + image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean") + return (caption_loss + image_loss) / 2.0 + + +class CLIPEncoderLayer(nn.TransformerEncoderLayer): + """The transformer encoder layer from CLIP.""" + + def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int): + super().__init__( + dims=hidden_dim, + mlp_dims=intermediate_dim, + num_heads=num_heads, + activation=quick_gelu, + norm_first=True, + ) + # Add biases to the attention projections + self.attention = nn.MultiHeadAttention(hidden_dim, num_heads, bias=True) + + +class CLIPTextModel(nn.Module): + """Implements the text encoder transformer from CLIP.""" + + def __init__(self, config: CLIPTextConfig): + super().__init__() + + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embedding = mx.zeros( + (config.max_position_embeddings, config.hidden_size) + ) + self.layers = [ + CLIPEncoderLayer( + config.hidden_size, config.intermediate_size, config.num_attention_heads + ) + for _ in range(config.num_hidden_layers) + ] + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + + def _embed(self, x: mx.array) -> mx.array: + embeddings = self.token_embedding(x) + embeddings += self.position_embedding[: x.shape[1]] + return embeddings + + def __call__(self, x: mx.array) -> CLIPTextOutput: + B, N = x.shape + eot_tokens = mx.argmax(x, axis=-1) + x = self._embed(x) + mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype) + for l in self.layers: + x = l(x, mask) + last_hidden_state = self.final_layer_norm(x) + pooler_output = last_hidden_state[mx.arange(B), eot_tokens] + + return CLIPTextOutput( + pooler_output=pooler_output, last_hidden_state=last_hidden_state + ) + + +class CLIPVisionModel(nn.Module): + """Implements the vision encoder transformer from CLIP.""" + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + + self.class_embedding = mx.zeros((config.hidden_size,)) + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + 
kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + num_patches = (config.image_size // config.patch_size) ** 2 + num_positions = num_patches + 1 + self.position_embedding = mx.zeros((num_positions, config.hidden_size)) + self.pre_layernorm = nn.LayerNorm(config.hidden_size) + self.layers = [ + CLIPEncoderLayer( + config.hidden_size, config.intermediate_size, config.num_attention_heads + ) + for _ in range(config.num_hidden_layers) + ] + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def _embed(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + # Patchify using conv: + # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim] + patch_embeddings = self.patch_embedding(x) + # [batch_size, num_patches, embed_dim] + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + # Prepend embeddings + # [batch_size, 1, embed_dim] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + # [batch_size, num_patches + 1, embed_dim] + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + # Add positional encoding + embeddings += self.position_embedding + return embeddings + + def __call__(self, x: mx.array) -> CLIPVisionOutput: + x = self._embed(x) + x = self.pre_layernorm(x) + + for l in self.layers: + x = l(x, mask=None) + + # Extract token embedding + pooler_output = self.post_layernorm(x[:, 0, :]) + return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x) + + +class CLIPModel(nn.Module): + def __init__(self, config: CLIPConfig): + self.text_model = CLIPTextModel(config.text_config) + self.vision_model = CLIPVisionModel(config.vision_config) + + text_embed_dim = config.text_config.hidden_size + vision_embed_dim = config.vision_config.hidden_size + projection_dim = config.projection_dim + + self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False) + self.text_projection = nn.Linear(text_embed_dim, projection_dim, bias=False) + self.logit_scale = mx.array(0.0) + + def get_text_features(self, x: mx.array) -> mx.array: + return self.text_projection(self.text_model(x).pooler_output) + + def get_image_features(self, x: mx.array) -> mx.array: + return self.visual_projection(self.vision_model(x).pooler_output) + + def __call__( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + return_loss=False, + ) -> CLIPModelOutput: + if input_ids is not None: + text_model_output = self.text_model(input_ids) + text_embeds = self.text_projection(text_model_output.pooler_output) + text_embeds = text_embeds / LA.norm(text_embeds, axis=-1, keepdims=True) + else: + text_embeds = None + text_model_output = None + + if pixel_values is not None: + vision_model_output = self.vision_model(pixel_values) + image_embeds = self.visual_projection(vision_model_output.pooler_output) + image_embeds = image_embeds / LA.norm(image_embeds, axis=-1, keepdims=True) + else: + image_embeds = None + vision_model_output = None + + if return_loss and (input_ids is None or pixel_values is None): + raise ValueError("Must provide text and image inputs to compute loss.") + + if return_loss: + logit_scale = mx.exp(self.logit_scale) + logits = (text_embeds @ image_embeds.T) * logit_scale + loss = clip_loss(logits) + else: + loss = None + + return CLIPModelOutput( + loss=loss, + text_embeds=text_embeds, + image_embeds=image_embeds, + vision_model_output=vision_model_output, + 
text_model_output=text_model_output, + ) + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + with open(path / "config.json", "r") as fid: + config = json.load(fid) + + text_config = config["text_config"] + text_config = CLIPTextConfig( + num_hidden_layers=text_config["num_hidden_layers"], + hidden_size=text_config["hidden_size"], + intermediate_size=text_config["intermediate_size"], + num_attention_heads=text_config["num_attention_heads"], + max_position_embeddings=text_config["max_position_embeddings"], + vocab_size=text_config["vocab_size"], + ) + + vision_config = config["vision_config"] + + vision_config = CLIPVisionConfig( + num_hidden_layers=vision_config["num_hidden_layers"], + hidden_size=vision_config["hidden_size"], + intermediate_size=vision_config["intermediate_size"], + num_attention_heads=vision_config["num_attention_heads"], + num_channels=3, + image_size=vision_config["image_size"], + patch_size=vision_config["patch_size"], + ) + + config = CLIPConfig( + text_config=text_config, + vision_config=vision_config, + projection_dim=config["projection_dim"], + ) + model = CLIPModel(config) + model.load_weights(str(path / "weights.npz")) + return model diff --git a/clip/requirements.txt b/clip/requirements.txt new file mode 100644 index 00000000..a6f2f1ca --- /dev/null +++ b/clip/requirements.txt @@ -0,0 +1,6 @@ +mlx +numpy +transformers +torch +huggingface_hub +Pillow \ No newline at end of file diff --git a/clip/test.py b/clip/test.py new file mode 100644 index 00000000..d8a2f84d --- /dev/null +++ b/clip/test.py @@ -0,0 +1,136 @@ +import unittest + +import mlx.core as mx +import model +import numpy as np +import torch +import transformers +from image_processor import CLIPImageProcessor +from PIL import Image +from tokenizer import CLIPTokenizer +from transformers import AutoTokenizer +from transformers.image_processing_utils import ChannelDimension + +MLX_PATH = "mlx_model" +HF_PATH = "openai/clip-vit-base-patch32" + + +def load_mlx_models(path): + image_proc = CLIPImageProcessor.from_pretrained(path) + tokenizer = CLIPTokenizer.from_pretrained(path) + clip = model.CLIPModel.from_pretrained(path) + return image_proc, tokenizer, clip + + +def load_hf_models(path): + image_proc = transformers.CLIPImageProcessor.from_pretrained(path) + tokenizer = AutoTokenizer.from_pretrained(path) + clip = transformers.CLIPModel.from_pretrained(path) + return image_proc, tokenizer, clip + + +class TestCLIP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mx_image_proc, cls.mx_tokenizer, cls.mx_clip = load_mlx_models(MLX_PATH) + cls.hf_image_proc, cls.hf_tokenizer, cls.hf_clip = load_hf_models(HF_PATH) + + def test_image_processor(self): + image = Image.open("assets/cat.jpeg") + + mx_data = self.mx_image_proc([image]) + hf_data = mx.array( + np.array( + self.hf_image_proc([image], data_format=ChannelDimension.LAST)[ + "pixel_values" + ] + ) + ) + self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) + + def test_text_tokenizer(self): + texts = ["a photo of a cat", "a photo of a dog"] + for txt in texts: + self.assertTrue( + np.array_equal( + self.mx_tokenizer.tokenize(txt)[None, :], + self.hf_tokenizer(txt, return_tensors="np")["input_ids"], + ), + ) + + def test_text_encoder(self): + texts = ["a photo of a cat", "a photo of a dog"] + # Tokenize + hf_tokens = self.hf_tokenizer(texts, return_tensors="pt") + mx_tokens = self.mx_tokenizer(texts) + # Get expected + with torch.inference_mode(): + expected_out = self.hf_clip.text_model(**hf_tokens) + 
expected_last_hidden = expected_out.last_hidden_state.numpy()
+            expected_pooler_output = expected_out.pooler_output.numpy()
+        out = self.mx_clip.text_model(mx_tokens)
+        self.assertTrue(
+            np.allclose(out.last_hidden_state, expected_last_hidden, atol=1e-5)
+        )
+        self.assertTrue(
+            np.allclose(out.pooler_output, expected_pooler_output, atol=1e-5)
+        )
+
+    def test_vision_encoder(self):
+        # Load and process test image
+        x = self.hf_image_proc(
+            images=[Image.open("assets/dog.jpeg")], return_tensors="np"
+        ).pixel_values
+
+        # Infer with HuggingFace model
+        with torch.inference_mode():
+            # Get expected
+            x_tc = torch.tensor(x)
+            expected_out = self.hf_clip.vision_model(x_tc)
+            expected_last_hidden = expected_out.last_hidden_state.numpy()
+            expected_pooler_output = expected_out.pooler_output.numpy()
+
+        # Test MLX vision encoder
+        out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
+        self.assertTrue(
+            np.allclose(
+                out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
+            ),
+        )
+        self.assertTrue(
+            np.allclose(
+                out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
+            ),
+        )
+
+    def test_clip_model(self):
+        image_input = self.hf_image_proc(
+            images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")],
+            return_tensors="np",
+        )["pixel_values"]
+        text = ["a photo of a cat", "a photo of a dog"]
+        tokens = self.hf_tokenizer(text, return_tensors="np")["input_ids"]
+        with torch.inference_mode():
+            expected_out = self.hf_clip(
+                input_ids=torch.tensor(tokens),
+                pixel_values=torch.tensor(image_input),
+                return_loss=True,
+            )
+
+        out = self.mx_clip(
+            input_ids=mx.array(tokens),
+            pixel_values=mx.array(image_input.transpose((0, 2, 3, 1))),
+            return_loss=True,
+        )
+
+        self.assertTrue(
+            np.allclose(out.text_embeds, expected_out.text_embeds, atol=1e-5)
+        )
+        self.assertTrue(
+            np.allclose(out.image_embeds, expected_out.image_embeds, atol=1e-5)
+        )
+        self.assertTrue(np.allclose(out.loss, expected_out.loss, atol=1e-5))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/clip/tokenizer.py b/clip/tokenizer.py
new file mode 100644
index 00000000..87daaa3f
--- /dev/null
+++ b/clip/tokenizer.py
@@ -0,0 +1,121 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import json
+from pathlib import Path
+from typing import Any
+
+import mlx.core as mx
+import regex
+
+
+class CLIPTokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab):
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
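+        # For example, with hypothetical merge ranks, "dog</w>" starts out as
+        # ["d", "o", "g</w>"]; if ("o", "g</w>") has the lowest rank it is
+        # merged into ["d", "og</w>"], and the loop repeats until no remaining
+        # bigram appears in bpe_ranks.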
+ # + # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_py + while unique_bigrams: + bigram = min( + unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) + ) + if bigram not in self.bpe_ranks: + break + + new_unigrams = [] + skip = False + for a, b in zip(unigrams, unigrams[1:]): + if skip: + skip = False + continue + + if (a, b) == bigram: + new_unigrams.append(a + b) + skip = True + + else: + new_unigrams.append(a) + + if not skip: + new_unigrams.append(b) + + unigrams = new_unigrams + unique_bigrams = set(zip(unigrams, unigrams[1:])) + + self._cache[text] = unigrams + + return unigrams + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.tokenize(*args, **kwargs) + + def tokenize(self, text, prepend_bos=True, append_eos=True) -> mx.array: + if isinstance(text, list): + return mx.array([self.tokenize(t, prepend_bos, append_eos) for t in text]) + + # Lower case, cleanup, and split. Hugging Face does a much, + # more thorough job here but this should suffice for 95% of + # cases. + clean_text = regex.sub(r"\s+", " ", text.lower()) + tokens = regex.findall(self.pat, clean_text) + + # Split the tokens according to the byte-pair merge file + bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] + + # Map to token ids and return + tokens = [] + if prepend_bos: + tokens.append(self.bos_token) + tokens.extend(self.vocab[t] for t in bpe_tokens) + if append_eos: + tokens.append(self.eos_token) + return mx.array(tokens) + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + with open(path / "vocab.json", encoding="utf-8") as f: + vocab = json.load(f) + with open(path / "merges.txt", encoding="utf-8") as f: + bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] + + bpe_merges = [tuple(m.split()) for m in bpe_merges] + bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) + + return CLIPTokenizer(bpe_ranks, vocab)
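+
+
+if __name__ == "__main__":
+    # A small usage sketch, not part of the original port: it assumes
+    # `python convert.py` has already written vocab.json and merges.txt
+    # to the mlx_model/ directory (see the README).
+    tokenizer = CLIPTokenizer.from_pretrained("mlx_model")
+    ids = tokenizer(["a photo of a cat", "a photo of a dog"])
+    # Each row starts with <|startoftext|> and ends with <|endoftext|>.
+    print("Token ids:", ids)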