mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00

* probably approximately correct CLIPTextEncoder * implemented CLIPEncoderLayer as built-in nn.TransformerEncoderLayer * replaced embedding layer with simple matrix * implemented ViT * added ViT tests * fixed tests * added pooler_output for text * implemented complete CLIPModel * implemented init * implemented convert.py and from_pretrained * fixed some minor bugs and added the README.md * removed tokenizer unused comments * removed unused deps * updated ACKNOWLEDGEMENTS.md * Feat: Image Processor for CLIP (#1) @nkasmanoff: * clip image processor * added example usage * refactored image preprocessing * deleted unused image_config.py * removed preprocessing port * added dependency to mlx-data * fixed attribution and moved photos to assets * implemented a simple port of CLIPImageProcessor * review changes * PR review changes * renamed too verbose arg * updated README.md * nits in readme / conversion * simplify some stuff, remove unneeded inits * remove more init stuff * more simplify * make test a unit test * update main readme * readme nits --------- Co-authored-by: Noah Kasmanoff <nkasmanoff@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
30 lines
713 B
Python
30 lines
713 B
Python
import mlx.core as mx
|
|
import transformers
|
|
from PIL import Image
|
|
|
|
import clip
|
|
|
|
# Example: run two captions and two images through the MLX CLIP model, then
# print the text embeddings, the image embeddings, and the loss the model
# returns when called with return_loss=True.

# Hugging Face checkpoint whose CLIPProcessor (tokenization + image
# preprocessing) is used to prepare the model inputs.
hf_model = "openai/clip-vit-base-patch32"
# Local path passed to clip.load to obtain the MLX model.
mlx_model = "mlx_model"

# Only the first element returned by clip.load (the model) is used here.
model, *_ = clip.load(mlx_model)
processor = transformers.CLIPProcessor.from_pretrained(hf_model)

# Image.open keeps the underlying file open; using the images as context
# managers guarantees the handles are closed once the processor has turned
# them into NumPy arrays.
with Image.open("assets/cat.jpeg") as cat_image, Image.open("assets/dog.jpeg") as dog_image:
    inputs = processor(
        text=["a photo of a cat", "a photo of a dog"],
        images=[cat_image, dog_image],
        return_tensors="np",
    )

out = model(
    input_ids=mx.array(inputs.input_ids),
    # Move axis 1 to the last position. NOTE(review): presumably the HF
    # processor emits channels-first (N, C, H, W) pixels and the MLX vision
    # model expects channels-last (N, H, W, C) — confirm against clip.py.
    pixel_values=mx.array(inputs.pixel_values).transpose((0, 2, 3, 1)),
    return_loss=True,
)

print("text embeddings:")
print(out.text_embeds)
print("image embeddings:")
print(out.image_embeds)
print(f"CLIP loss: {out.loss.item():.3f}")