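# Consistency tests for the MLX CLIP port: each test compares the MLX models
# against the reference Hugging Face implementation.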
import unittest

import mlx.core as mx
import model
import numpy as np
import torch
import transformers
from image_processor import CLIPImageProcessor
from PIL import Image
from tokenizer import CLIPTokenizer
from transformers import AutoTokenizer
from transformers.image_processing_utils import ChannelDimension

# Directory with the locally converted MLX weights, and the Hugging Face
# checkpoint that serves as the reference implementation.
MLX_PATH = "mlx_model"
HF_PATH = "openai/clip-vit-base-patch32"


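# Load the MLX image processor, tokenizer, and CLIP model from a local path.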
def load_mlx_models(path):
    image_proc = CLIPImageProcessor.from_pretrained(path)
    tokenizer = CLIPTokenizer.from_pretrained(path)
    clip = model.CLIPModel.from_pretrained(path)
    return image_proc, tokenizer, clip


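# Load the reference Hugging Face image processor, tokenizer, and CLIP model.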
def load_hf_models(path):
    image_proc = transformers.CLIPImageProcessor.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)
    clip = transformers.CLIPModel.from_pretrained(path)
    return image_proc, tokenizer, clip


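# The tests cover the image processor, the tokenizer, the text and vision
# encoders, and the full model end to end.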
class TestCLIP(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Load both stacks once and reuse them across all tests.
        cls.mx_image_proc, cls.mx_tokenizer, cls.mx_clip = load_mlx_models(MLX_PATH)
        cls.hf_image_proc, cls.hf_tokenizer, cls.hf_clip = load_hf_models(HF_PATH)

    def test_image_processor(self):
        image = Image.open("assets/cat.jpeg")

        mx_data = self.mx_image_proc([image])
        # The Hugging Face processor returns channels-first pixel values by
        # default; request channels-last so the layouts match before comparing.
        hf_data = mx.array(
            np.array(
                self.hf_image_proc([image], data_format=ChannelDimension.LAST)[
                    "pixel_values"
                ]
            )
        )
        self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5))

    def test_text_tokenizer(self):
        texts = ["a photo of a cat", "a photo of a dog"]
        for txt in texts:
            # The MLX tokenizer returns a 1-D token array; add a batch axis
            # to match the Hugging Face (batch, sequence) output.
            self.assertTrue(
                np.array_equal(
                    self.mx_tokenizer.tokenize(txt)[None, :],
                    self.hf_tokenizer(txt, return_tensors="np")["input_ids"],
                ),
            )

    def test_text_encoder(self):
        texts = ["a photo of a cat", "a photo of a dog"]
        # Tokenize
        hf_tokens = self.hf_tokenizer(texts, return_tensors="pt")
        mx_tokens = self.mx_tokenizer(texts)
        # Get the expected outputs from the reference PyTorch text encoder
        with torch.inference_mode():
            expected_out = self.hf_clip.text_model(**hf_tokens)
            expected_last_hidden = expected_out.last_hidden_state.numpy()
            expected_pooler_output = expected_out.pooler_output.numpy()
        out = self.mx_clip.text_model(mx_tokens)
        self.assertTrue(
            np.allclose(out.last_hidden_state, expected_last_hidden, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(out.pooler_output, expected_pooler_output, atol=1e-5)
        )

    def test_vision_encoder(self):
        # Load and process the test image
        x = self.hf_image_proc(
            images=[Image.open("assets/dog.jpeg")], return_tensors="np"
        ).pixel_values

        # Get the expected outputs from the Hugging Face vision encoder
        with torch.inference_mode():
            x_tc = torch.tensor(x)
            expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
            expected_last_hidden = expected_out.last_hidden_state.numpy()
            expected_pooler_output = expected_out.pooler_output.numpy()
            expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]

        # Test the MLX vision encoder; it expects NHWC input, so transpose the
        # NCHW pixel values produced by the Hugging Face processor.
        out = self.mx_clip.vision_model(
            mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
        )
        self.assertTrue(
            np.allclose(
                out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
            ),
        )
        self.assertTrue(
            np.allclose(
                out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
            ),
        )
        for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
            self.assertTrue(
                np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
            )

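    # End-to-end check: text/image embeddings and the contrastive loss from
    # the full CLIP model.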
    def test_clip_model(self):
        image_input = self.hf_image_proc(
            images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")],
            return_tensors="np",
        )["pixel_values"]
        text = ["a photo of a cat", "a photo of a dog"]
        tokens = self.hf_tokenizer(text, return_tensors="np")["input_ids"]

        with torch.inference_mode():
            expected_out = self.hf_clip(
                input_ids=torch.tensor(tokens),
                pixel_values=torch.tensor(image_input),
                return_loss=True,
            )

        # The MLX model takes channels-last pixel values, hence the transpose.
        out = self.mx_clip(
            input_ids=mx.array(tokens),
            pixel_values=mx.array(image_input.transpose((0, 2, 3, 1))),
            return_loss=True,
        )

        self.assertTrue(
            np.allclose(out.text_embeds, expected_out.text_embeds, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(out.image_embeds, expected_out.image_embeds, atol=1e-5)
        )
        self.assertTrue(np.allclose(out.loss, expected_out.loss, atol=1e-5))


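# Standard unittest entry point: running this file executes the whole suite,
# and a single test can be selected with, e.g.,
# `python <this file> TestCLIP.test_clip_model`.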
if __name__ == "__main__":
    unittest.main()