CLIP (ViT) (#315)

* probably approximately correct CLIPTextEncoder

* implemented CLIPEncoderLayer on top of the built-in nn.TransformerEncoderLayer

* replaced embedding layer with simple matrix

* implemented ViT

* added ViT tests

* fixed tests

* added pooler_output for text

* implemented complete CLIPModel

* implemented init

* implemented convert.py and from_pretrained

* fixed some minor bugs and added the README.md

* removed tokenizer unused comments

* removed unused deps

* updated ACKNOWLEDGEMENTS.md

* Feat: Image Processor for CLIP (#1)

@nkasmanoff:
* clip image processor
* added example usage

* refactored image preprocessing

* deleted unused image_config.py

* removed preprocessing port

* added dependency on mlx-data

* fixed attribution and moved photos to assets

* implemented a simple port of CLIPImageProcessor

* review changes

* PR review changes

* renamed too verbose arg

* updated README.md

* nits in readme / conversion

* simplify some stuff, remove unneeded inits

* remove more init stuff

* more simplify

* make test a unit test

* update main readme

* readme nits

---------

Co-authored-by: Noah Kasmanoff <nkasmanoff@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Gabrijel Boduljak 2024-01-31 23:19:53 +01:00 committed by GitHub
parent ba3a9355d1
commit 94358219cf
14 changed files with 890 additions and 0 deletions

ACKNOWLEDGEMENTS.md

@@ -10,3 +10,4 @@ MLX Examples was developed with contributions from the following individuals:
- Juarez Bochi: Added support for T5 models.
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
- Shunta Saito: Added support for PLaMo models.
- Gabrijel Boduljak: Implemented `CLIP`.

README.md

@@ -26,9 +26,15 @@ Some more useful examples are listed below.

- Speech recognition with [OpenAI's Whisper](whisper).

### Multimodal models

- Joint text and image embeddings with [CLIP](clip).

### Other Models

- Semi-supervised learning on graph-structured data with [GCN](gcn).
- Real NVP [normalizing flow](normalizing_flow) for density estimation and
  sampling.

### Hugging Face

1
clip/.gitignore vendored Normal file

@@ -0,0 +1 @@
mlx_model/

76
clip/README.md Normal file

@@ -0,0 +1,76 @@
# CLIP

An example of OpenAI's CLIP in MLX. The CLIP (contrastive language-image
pre-training) model embeds images and text in the same space.[^1]

### Setup

Install the dependencies:

```shell
pip install -r requirements.txt
```

Next, download a CLIP model from Hugging Face and convert it to MLX. The
default model is
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).

```shell
python convert.py
```

The script will by default download the model and configuration files to the
directory ``mlx_model/``.

### Run

You can use the CLIP model to embed images and text.

```python
from PIL import Image
import clip

model, tokenizer, img_processor = clip.load("mlx_model")
inputs = {
    "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]),
    "pixel_values": img_processor(
        [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
    ),
}
output = model(**inputs)

# Get text and image embeddings:
text_embeds = output.text_embeds
image_embeds = output.image_embeds
```

Run the above example with `python clip.py`.

To embed only images or only the text, pass only the ``input_ids`` or
``pixel_values``, respectively.
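
For example, a text-only query just omits ``pixel_values``. The following is a
minimal sketch reusing the API above (the embedding size shown is for the
default base-patch32 model):

```python
import clip

# Load the converted model (assumes `python convert.py` has been run).
model, tokenizer, _ = clip.load("mlx_model")

# Passing only `input_ids` skips the vision tower entirely.
output = model(input_ids=tokenizer(["a photo of a cat"]))
print(output.text_embeds.shape)  # (1, 512) for openai/clip-vit-base-patch32
print(output.image_embeds)       # None, since no pixel_values were given
```
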
This example re-implements minimal image preprocessing and tokenization to reduce
dependencies. For additional preprocessing functionality, you can use
``transformers``. The file `hf_preproc.py` has an example.

MLX CLIP has been tested and works with the following Hugging Face repos:

- [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)
- [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)

You can run the tests with:

```shell
python test.py
```

To test new models, update the `MLX_PATH` and `HF_PATH` in `test.py`.

### Attribution

- `assets/cat.jpeg` is a "Cat" by London's, licensed under CC BY-SA 2.0.
- `assets/dog.jpeg` is a "Happy Dog" by tedmurphy, licensed under CC BY 2.0.

[^1]: Refer to the original paper [Learning Transferable Visual Models From
    Natural Language Supervision](https://arxiv.org/abs/2103.00020) or [blog
    post](https://openai.com/research/clip).

BIN
clip/assets/cat.jpeg Normal file (binary image, 193 KiB)

BIN
clip/assets/dog.jpeg Normal file (binary image, 90 KiB)

31
clip/clip.py Normal file

@@ -0,0 +1,31 @@
from typing import Tuple

from image_processor import CLIPImageProcessor
from model import CLIPModel
from tokenizer import CLIPTokenizer


def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]:
    model = CLIPModel.from_pretrained(model_dir)
    tokenizer = CLIPTokenizer.from_pretrained(model_dir)
    img_processor = CLIPImageProcessor.from_pretrained(model_dir)
    return model, tokenizer, img_processor


if __name__ == "__main__":
    from PIL import Image

    model, tokenizer, img_processor = load("mlx_model")

    inputs = {
        "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]),
        "pixel_values": img_processor(
            [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
        ),
    }

    output = model(**inputs)

    # Get text and image embeddings:
    text_embeds = output.text_embeds
    image_embeds = output.image_embeds

    print("Text embeddings shape:", text_embeds.shape)
    print("Image embeddings shape:", image_embeds.shape)

107
clip/convert.py Normal file

@@ -0,0 +1,107 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import shutil
from pathlib import Path
from typing import Tuple

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def get_model_path(path_or_hf_repo: str) -> Path:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=[
                    "*.bin",
                    "*.json",
                    "*.txt",
                ],
            )
        )
    return model_path


def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
    a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
    return mx.array(a.numpy(), getattr(mx, dtype))


def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
    key = key.replace("embeddings.", "")
    key = key.replace("encoder.", "")
    key = key.replace("position_embedding.weight", "position_embedding")

    # Map attention layers
    if "self_attn." in key:
        key = key.replace("self_attn.", "attention.")
    if "q_proj." in key:
        key = key.replace("q_proj.", "query_proj.")
    if "k_proj." in key:
        key = key.replace("k_proj.", "key_proj.")
    if "v_proj." in key:
        key = key.replace("v_proj.", "value_proj.")
    if "layer_norm1." in key:
        key = key.replace("layer_norm1.", "ln1.")
    if "layer_norm2." in key:
        key = key.replace("layer_norm2.", "ln2.")
    # Map ffn layers
    if "mlp.fc1" in key:
        key = key.replace("mlp.fc1", "linear1")
    if "mlp.fc2" in key:
        key = key.replace("mlp.fc2", "linear2")
    # Fix layernorm typo
    if "pre_layrnorm" in key:
        # Fix typo in weights :)
        key = key.replace("pre_layrnorm", "pre_layernorm")
    if "patch_embedding.weight" in key:
        # Initially, value: [out_channels, in_channels, kH, KW].
        # We want [out_channels, kH, KW, in_channels]
        value = value.permute(0, 2, 3, 1)
    return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))


def should_keep_weight(key: str):
    return not ("position_ids" in key)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download and Convert (OpenAI) CLIP weights to MLX"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default="openai/clip-vit-base-patch32",
        help="Hugging Face repository name.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    args = parser.parse_args()

    torch_path = get_model_path(args.hf_repo)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    print("[INFO] Loading")
    torch_weights = torch.load(torch_path / "pytorch_model.bin")
    print("[INFO] Converting")
    mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
    mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
    print("[INFO] Saving")
    mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
    for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
        shutil.copyfile(
            str(torch_path / f"{fn}"),
            str(mlx_path / f"{fn}"),
        )
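
As a usage sketch, the larger checkpoint listed in `clip/README.md` can be
converted into its own directory using the two flags defined above; the output
directory name here is just an example:

```shell
# Convert the ViT-L/14 checkpoint into a separate MLX directory.
python convert.py \
    --hf-repo openai/clip-vit-large-patch14 \
    --mlx-path mlx_model_large
```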

29
clip/hf_preproc.py Normal file

@@ -0,0 +1,29 @@
import mlx.core as mx
import transformers
from PIL import Image

import clip

hf_model = "openai/clip-vit-base-patch32"
mlx_model = "mlx_model"

model, *_ = clip.load(mlx_model)
processor = transformers.CLIPProcessor.from_pretrained(hf_model)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")],
    return_tensors="np",
)

out = model(
    input_ids=mx.array(inputs.input_ids),
    pixel_values=mx.array(inputs.pixel_values).transpose((0, 2, 3, 1)),
    return_loss=True,
)

print("text embeddings:")
print(out.text_embeds)
print("image embeddings:")
print(out.image_embeds)
print(f"CLIP loss: {out.loss.item():.3f}")

93
clip/image_processor.py Normal file

@@ -0,0 +1,93 @@
# Copyright © 2023-2024 Apple Inc.

import json
from pathlib import Path
from typing import List, Tuple

import mlx.core as mx
import numpy as np
from PIL.Image import Image


class CLIPImageProcessor:
    """
    A simple port of
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py.
    """

    def __init__(
        self,
        crop_size: int = 224,
        do_center_crop: bool = True,
        do_normalize: bool = True,
        do_resize: bool = True,
        image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
        image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
        size: int = 224,
        **kwargs
    ) -> None:
        self.crop_size = crop_size
        self.do_center_crop = do_center_crop
        self.do_normalize = do_normalize
        self.do_resize = do_resize
        self.image_mean = mx.array(image_mean)
        self.image_std = mx.array(image_std)
        self.size = size

    def __call__(self, images: List[Image]) -> mx.array:
        return mx.concatenate(
            [self._preprocess(image)[None] for image in images], axis=0
        )

    def _preprocess(self, image: Image) -> mx.array:
        if self.do_resize:
            image = resize(image, self.size)
        if self.do_center_crop:
            image = center_crop(image, (self.crop_size, self.crop_size))
        image = mx.array(np.array(image))
        image = rescale(image)
        if self.do_normalize:
            image = normalize(image, self.image_mean, self.image_std)
        return image

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)
        with open(path / "preprocessor_config.json", encoding="utf-8") as f:
            config = json.load(f)
        return CLIPImageProcessor(**config)


def resize(image: Image, short_size: int) -> Image:
    """
    Resize so that the shorter side of the image is exactly short_size.
    """
    width, height = image.size
    short = min(width, height)
    long = max(width, height)
    if short == short_size:
        return image
    new_short = short_size
    new_long = int(short_size * long / short)
    new_size = (new_short, new_long) if width <= height else (new_long, new_short)
    return image.resize(new_size)


def center_crop(image: Image, size: Tuple[int, int]) -> Image:
    if size[0] % 2 != 0 or size[1] % 2 != 0:
        raise ValueError("Only even crop sizes supported.")
    original_width, original_height = image.size
    crop_height, crop_width = size
    top = (original_height - crop_height) // 2
    bottom = top + crop_height
    left = (original_width - crop_width) // 2
    right = left + crop_width
    return image.crop((left, top, right, bottom))


def rescale(image: mx.array) -> mx.array:
    return image.astype(mx.float32) * (1 / 255.0)


def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array:
    return (image - mean) / std

283
clip/model.py Normal file

@@ -0,0 +1,283 @@
# Copyright © 2023-2024 Apple Inc.

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.core import linalg as LA
from mlx.nn.losses import cross_entropy
from mlx.utils import tree_flatten


@dataclass
class CLIPVisionOutput:
    pooler_output: mx.array
    last_hidden_state: mx.array


@dataclass
class CLIPTextOutput:
    pooler_output: mx.array
    last_hidden_state: mx.array


@dataclass
class CLIPModelOutput:
    loss: Optional[mx.array]
    text_embeds: Optional[mx.array]
    image_embeds: Optional[mx.array]
    text_model_output: CLIPTextOutput
    vision_model_output: CLIPVisionOutput


@dataclass
class CLIPTextConfig:
    num_hidden_layers: int
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    max_position_embeddings: int
    vocab_size: int


@dataclass
class CLIPVisionConfig:
    num_hidden_layers: int
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_channels: int
    image_size: int
    patch_size: int


@dataclass
class CLIPConfig:
    text_config: CLIPTextConfig
    vision_config: CLIPVisionConfig
    projection_dim: int


def quick_gelu(x: mx.array) -> mx.array:
    """
    A fast GELU approximation https://github.com/hendrycks/GELUs
    """
    return x * mx.sigmoid(1.702 * x)


def clip_loss(logits: mx.array) -> mx.array:
    N, M = logits.shape
    caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean")
    image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean")
    return (caption_loss + image_loss) / 2.0


class CLIPEncoderLayer(nn.TransformerEncoderLayer):
    """The transformer encoder layer from CLIP."""

    def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int):
        super().__init__(
            dims=hidden_dim,
            mlp_dims=intermediate_dim,
            num_heads=num_heads,
            activation=quick_gelu,
            norm_first=True,
        )
        # Add biases to the attention projections
        self.attention = nn.MultiHeadAttention(hidden_dim, num_heads, bias=True)


class CLIPTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = mx.zeros(
            (config.max_position_embeddings, config.hidden_size)
        )
        self.layers = [
            CLIPEncoderLayer(
                config.hidden_size, config.intermediate_size, config.num_attention_heads
            )
            for _ in range(config.num_hidden_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)

    def _embed(self, x: mx.array) -> mx.array:
        embeddings = self.token_embedding(x)
        embeddings += self.position_embedding[: x.shape[1]]
        return embeddings

    def __call__(self, x: mx.array) -> CLIPTextOutput:
        B, N = x.shape
        eot_tokens = mx.argmax(x, axis=-1)
        x = self._embed(x)
        mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
        for l in self.layers:
            x = l(x, mask)
        last_hidden_state = self.final_layer_norm(x)
        pooler_output = last_hidden_state[mx.arange(B), eot_tokens]
        return CLIPTextOutput(
            pooler_output=pooler_output, last_hidden_state=last_hidden_state
        )


class CLIPVisionModel(nn.Module):
    """Implements the vision encoder transformer from CLIP."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.class_embedding = mx.zeros((config.hidden_size,))
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        num_patches = (config.image_size // config.patch_size) ** 2
        num_positions = num_patches + 1
        self.position_embedding = mx.zeros((num_positions, config.hidden_size))
        self.pre_layernorm = nn.LayerNorm(config.hidden_size)
        self.layers = [
            CLIPEncoderLayer(
                config.hidden_size, config.intermediate_size, config.num_attention_heads
            )
            for _ in range(config.num_hidden_layers)
        ]
        self.post_layernorm = nn.LayerNorm(config.hidden_size)

    def _embed(self, x: mx.array) -> mx.array:
        batch_size = x.shape[0]
        # Patchify using conv:
        # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim]
        patch_embeddings = self.patch_embedding(x)
        # [batch_size, num_patches, embed_dim]
        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
        embed_dim = patch_embeddings.shape[-1]
        # Prepend <CLS> embeddings
        # [batch_size, 1, embed_dim]
        cls_embeddings = mx.broadcast_to(
            self.class_embedding, (batch_size, 1, embed_dim)
        )
        # [batch_size, num_patches + 1, embed_dim]
        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
        # Add positional encoding
        embeddings += self.position_embedding
        return embeddings

    def __call__(self, x: mx.array) -> CLIPVisionOutput:
        x = self._embed(x)
        x = self.pre_layernorm(x)
        for l in self.layers:
            x = l(x, mask=None)
        # Extract <CLS> token embedding
        pooler_output = self.post_layernorm(x[:, 0, :])
        return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x)


class CLIPModel(nn.Module):
    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.text_model = CLIPTextModel(config.text_config)
        self.vision_model = CLIPVisionModel(config.vision_config)

        text_embed_dim = config.text_config.hidden_size
        vision_embed_dim = config.vision_config.hidden_size
        projection_dim = config.projection_dim

        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
        self.text_projection = nn.Linear(text_embed_dim, projection_dim, bias=False)
        self.logit_scale = mx.array(0.0)

    def get_text_features(self, x: mx.array) -> mx.array:
        return self.text_projection(self.text_model(x).pooler_output)

    def get_image_features(self, x: mx.array) -> mx.array:
        return self.visual_projection(self.vision_model(x).pooler_output)

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        return_loss=False,
    ) -> CLIPModelOutput:
        if input_ids is not None:
            text_model_output = self.text_model(input_ids)
            text_embeds = self.text_projection(text_model_output.pooler_output)
            text_embeds = text_embeds / LA.norm(text_embeds, axis=-1, keepdims=True)
        else:
            text_embeds = None
            text_model_output = None

        if pixel_values is not None:
            vision_model_output = self.vision_model(pixel_values)
            image_embeds = self.visual_projection(vision_model_output.pooler_output)
            image_embeds = image_embeds / LA.norm(image_embeds, axis=-1, keepdims=True)
        else:
            image_embeds = None
            vision_model_output = None

        if return_loss and (input_ids is None or pixel_values is None):
            raise ValueError("Must provide text and image inputs to compute loss.")

        if return_loss:
            logit_scale = mx.exp(self.logit_scale)
            logits = (text_embeds @ image_embeds.T) * logit_scale
            loss = clip_loss(logits)
        else:
            loss = None

        return CLIPModelOutput(
            loss=loss,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            vision_model_output=vision_model_output,
            text_model_output=text_model_output,
        )

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        with open(path / "config.json", "r") as fid:
            config = json.load(fid)

        text_config = config["text_config"]
        text_config = CLIPTextConfig(
            num_hidden_layers=text_config["num_hidden_layers"],
            hidden_size=text_config["hidden_size"],
            intermediate_size=text_config["intermediate_size"],
            num_attention_heads=text_config["num_attention_heads"],
            max_position_embeddings=text_config["max_position_embeddings"],
            vocab_size=text_config["vocab_size"],
        )

        vision_config = config["vision_config"]
        vision_config = CLIPVisionConfig(
            num_hidden_layers=vision_config["num_hidden_layers"],
            hidden_size=vision_config["hidden_size"],
            intermediate_size=vision_config["intermediate_size"],
            num_attention_heads=vision_config["num_attention_heads"],
            num_channels=3,
            image_size=vision_config["image_size"],
            patch_size=vision_config["patch_size"],
        )

        config = CLIPConfig(
            text_config=text_config,
            vision_config=vision_config,
            projection_dim=config["projection_dim"],
        )
        model = CLIPModel(config)
        model.load_weights(str(path / "weights.npz"))
        return model
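
As a usage sketch built on the classes above (assuming a converted model in
`mlx_model/` and the bundled `assets/` images), the projection heads can be
used directly for a simple zero-shot comparison:

```python
import mlx.core as mx
from PIL import Image

from image_processor import CLIPImageProcessor
from model import CLIPModel
from tokenizer import CLIPTokenizer

model = CLIPModel.from_pretrained("mlx_model")
tokenizer = CLIPTokenizer.from_pretrained("mlx_model")
img_processor = CLIPImageProcessor.from_pretrained("mlx_model")

# Unnormalized projections from each tower.
text_features = model.get_text_features(
    tokenizer(["a photo of a cat", "a photo of a dog"])
)
image_features = model.get_image_features(
    img_processor([Image.open("assets/cat.jpeg")])
)

# Normalize and compare: cosine similarity between the image and each prompt.
text_features = text_features / mx.linalg.norm(text_features, axis=-1, keepdims=True)
image_features = image_features / mx.linalg.norm(image_features, axis=-1, keepdims=True)
print(image_features @ text_features.T)  # the matching caption scores higher
```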

6
clip/requirements.txt Normal file

@@ -0,0 +1,6 @@
mlx
numpy
transformers
torch
huggingface_hub
Pillow

136
clip/test.py Normal file

@@ -0,0 +1,136 @@
import unittest

import mlx.core as mx
import model
import numpy as np
import torch
import transformers
from image_processor import CLIPImageProcessor
from PIL import Image
from tokenizer import CLIPTokenizer
from transformers import AutoTokenizer
from transformers.image_processing_utils import ChannelDimension

MLX_PATH = "mlx_model"
HF_PATH = "openai/clip-vit-base-patch32"


def load_mlx_models(path):
    image_proc = CLIPImageProcessor.from_pretrained(path)
    tokenizer = CLIPTokenizer.from_pretrained(path)
    clip = model.CLIPModel.from_pretrained(path)
    return image_proc, tokenizer, clip


def load_hf_models(path):
    image_proc = transformers.CLIPImageProcessor.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)
    clip = transformers.CLIPModel.from_pretrained(path)
    return image_proc, tokenizer, clip


class TestCLIP(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.mx_image_proc, cls.mx_tokenizer, cls.mx_clip = load_mlx_models(MLX_PATH)
        cls.hf_image_proc, cls.hf_tokenizer, cls.hf_clip = load_hf_models(HF_PATH)

    def test_image_processor(self):
        image = Image.open("assets/cat.jpeg")
        mx_data = self.mx_image_proc([image])
        hf_data = mx.array(
            np.array(
                self.hf_image_proc([image], data_format=ChannelDimension.LAST)[
                    "pixel_values"
                ]
            )
        )
        self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5))

    def test_text_tokenizer(self):
        texts = ["a photo of a cat", "a photo of a dog"]
        for txt in texts:
            self.assertTrue(
                np.array_equal(
                    self.mx_tokenizer.tokenize(txt)[None, :],
                    self.hf_tokenizer(txt, return_tensors="np")["input_ids"],
                ),
            )

    def test_text_encoder(self):
        texts = ["a photo of a cat", "a photo of a dog"]
        # Tokenize
        hf_tokens = self.hf_tokenizer(texts, return_tensors="pt")
        mx_tokens = self.mx_tokenizer(texts)
        # Get expected
        with torch.inference_mode():
            expected_out = self.hf_clip.text_model(**hf_tokens)
            expected_last_hidden = expected_out.last_hidden_state.numpy()
            expected_pooler_output = expected_out.pooler_output.numpy()
        out = self.mx_clip.text_model(mx_tokens)
        self.assertTrue(
            np.allclose(out.last_hidden_state, expected_last_hidden, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(out.pooler_output, expected_pooler_output, atol=1e-5)
        )

    def test_vision_encoder(self):
        # Load and process test image
        x = self.hf_image_proc(
            images=[Image.open("assets/dog.jpeg")], return_tensors="np"
        ).pixel_values

        # Infer with HuggingFace model
        with torch.inference_mode():
            # Get expected
            x_tc = torch.tensor(x)
            expected_out = self.hf_clip.vision_model(x_tc)
            expected_last_hidden = expected_out.last_hidden_state.numpy()
            expected_pooler_output = expected_out.pooler_output.numpy()

        # Test MLX vision encoder
        out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
        self.assertTrue(
            np.allclose(
                out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
            ),
        )
        self.assertTrue(
            np.allclose(
                out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
            ),
        )

    def test_clip_model(self):
        image_input = self.hf_image_proc(
            images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")],
            return_tensors="np",
        )["pixel_values"]
        text = ["a photo of a cat", "a photo of a dog"]
        tokens = self.hf_tokenizer(text, return_tensors="np")["input_ids"]

        with torch.inference_mode():
            expected_out = self.hf_clip(
                input_ids=torch.tensor(tokens),
                pixel_values=torch.tensor(image_input),
                return_loss=True,
            )

        out = self.mx_clip(
            input_ids=mx.array(tokens),
            pixel_values=mx.array(image_input.transpose((0, 2, 3, 1))),
            return_loss=True,
        )
        self.assertTrue(
            np.allclose(out.text_embeds, expected_out.text_embeds, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(out.image_embeds, expected_out.image_embeds, atol=1e-5)
        )
        self.assertTrue(np.allclose(out.loss, expected_out.loss, atol=1e-5))


if __name__ == "__main__":
    unittest.main()
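
To run a single parity check rather than the whole suite, the standard
`unittest` selector can be pointed at the names defined above, for example:

```shell
# Compare only the MLX text encoder against the Hugging Face reference.
python -m unittest test.TestCLIP.test_text_encoder -v
```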

121
clip/tokenizer.py Normal file

@@ -0,0 +1,121 @@
# Copyright © 2023-2024 Apple Inc.

import json
from pathlib import Path
from typing import Any

import mlx.core as mx
import regex


class CLIPTokenizer:
    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""

    def __init__(self, bpe_ranks, vocab):
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )
        self._cache = {self.bos: self.bos, self.eos: self.eos}

    @property
    def bos(self):
        return "<|startoftext|>"

    @property
    def bos_token(self):
        return self.vocab[self.bos]

    @property
    def eos(self):
        return "<|endoftext|>"

    @property
    def eos_token(self):
        return self.vocab[self.eos]

    def bpe(self, text):
        if text in self._cache:
            return self._cache[text]

        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
        unique_bigrams = set(zip(unigrams, unigrams[1:]))

        if not unique_bigrams:
            return unigrams

        # In every iteration try to merge the two most likely bigrams. If none
        # was merged we are done.
        #
        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_py
        while unique_bigrams:
            bigram = min(
                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
            )
            if bigram not in self.bpe_ranks:
                break

            new_unigrams = []
            skip = False
            for a, b in zip(unigrams, unigrams[1:]):
                if skip:
                    skip = False
                    continue

                if (a, b) == bigram:
                    new_unigrams.append(a + b)
                    skip = True
                else:
                    new_unigrams.append(a)

            if not skip:
                new_unigrams.append(b)

            unigrams = new_unigrams
            unique_bigrams = set(zip(unigrams, unigrams[1:]))

        self._cache[text] = unigrams
        return unigrams

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.tokenize(*args, **kwargs)

    def tokenize(self, text, prepend_bos=True, append_eos=True) -> mx.array:
        if isinstance(text, list):
            return mx.array([self.tokenize(t, prepend_bos, append_eos) for t in text])

        # Lower case, cleanup, and split. Hugging Face does a much more
        # thorough job here but this should suffice for 95% of cases.
        clean_text = regex.sub(r"\s+", " ", text.lower())
        tokens = regex.findall(self.pat, clean_text)

        # Split the tokens according to the byte-pair merge file
        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]

        # Map to token ids and return
        tokens = []
        if prepend_bos:
            tokens.append(self.bos_token)
        tokens.extend(self.vocab[t] for t in bpe_tokens)
        if append_eos:
            tokens.append(self.eos_token)
        return mx.array(tokens)

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        with open(path / "vocab.json", encoding="utf-8") as f:
            vocab = json.load(f)

        with open(path / "merges.txt", encoding="utf-8") as f:
            bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]

        bpe_merges = [tuple(m.split()) for m in bpe_merges]
        bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))

        return CLIPTokenizer(bpe_ranks, vocab)
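
A small usage sketch for the tokenizer above, assuming `python convert.py` has
copied `vocab.json` and `merges.txt` into `mlx_model/` (the printed ids are
illustrative, taken from the base-patch32 vocabulary):

```python
from tokenizer import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("mlx_model")

# A single string becomes a 1-D array of ids, bracketed by BOS and EOS.
ids = tokenizer.tokenize("a photo of a cat")
print(ids)  # e.g. array([49406, 320, 1125, 539, 320, 2368, 49407])

# A list of equally long captions is batched into a 2-D array.
batch = tokenizer(["a photo of a cat", "a photo of a dog"])
print(batch.shape)  # (2, 7)
```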