# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

from llava import LlavaModel

MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"

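# Helpers to load the MLX port and the reference Hugging Face implementation of LLaVA.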
def load_mlx_models(path):
    model = LlavaModel.from_pretrained(path)
    model.eval()
    return model


def load_hf_models(path):
    model = LlavaForConditionalGeneration.from_pretrained(path)
    model.eval()
    return model

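# Compare the MLX vision tower and multi-modal projector against the Hugging Face reference.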
class TestVisionTower(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.mx_llava = load_mlx_models(MODEL_PATH)
        cls.hf_llava = load_hf_models(MODEL_PATH)
        cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)

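    # The projected image features from the two implementations should match closely.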
    def test_image_features(self):
        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
        vision_feature_layer = -2
        with torch.no_grad():
            pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
                "pixel_values"
            ]

            hf_pixel_values = pixel_values
            mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)

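            # Run the MLX vision tower on NHWC pixel values (hence the transpose above)
            # and project the selected hidden layer with the multi-modal projector.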
            _, _, hidden_states = self.mx_llava.vision_tower(
                mx_pixel_values,
                output_hidden_states=True,
            )

            mx_elected_image_feature = hidden_states[vision_feature_layer]
            mx_image_features = self.mx_llava.multi_modal_projector(
                mx_elected_image_feature
            )

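            # Repeat the same feature extraction with the Hugging Face model.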
            hf_image_outputs = self.hf_llava.vision_tower(
                hf_pixel_values, output_hidden_states=True
            )
            hf_elected_image_feature = hf_image_outputs.hidden_states[
                vision_feature_layer
            ]
            hf_image_features = self.hf_llava.multi_modal_projector(
                hf_elected_image_feature
            )

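            # Both sets of projected image features should agree within a small tolerance.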
            self.assertTrue(
                mx.allclose(
                    mx_image_features,
                    mx.array(hf_image_features.numpy()),
                    atol=1e-2,
                )
            )

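# End-to-end checks: embedding merging and next-token agreement with the Hugging Face model.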
class TestLlava(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.mx_llava = load_mlx_models(MODEL_PATH)
        cls.hf_llava = load_hf_models(MODEL_PATH)
        cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)

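    # Merging image features into the text embeddings should give the same result in both implementations.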
    def test_merge_input_ids_with_image_features(self):
        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
        vision_feature_layer = -2
        with torch.no_grad():
            values = self.proc(PROMPT, raw_image, return_tensors="pt")
            pixel_values = values["pixel_values"]
            input_ids = values["input_ids"]

            hf_pixel_values = pixel_values
            mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)

            _, _, hidden_states = self.mx_llava.vision_tower(
                mx_pixel_values,
                output_hidden_states=True,
            )
            mx_input_ids = mx.array(input_ids.numpy())
            mx_elected_image_feature = hidden_states[vision_feature_layer]
            mx_image_features = self.mx_llava.multi_modal_projector(
                mx_elected_image_feature
            )
            mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens(
                mx_input_ids
            )
            mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features(
                mx_image_features, mx_inputs_embeds, mx_input_ids
            )

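            # Build the reference merged embedding with the Hugging Face model.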
            hf_image_outputs = self.hf_llava.vision_tower(
                hf_pixel_values, output_hidden_states=True
            )
            hf_elected_image_feature = hf_image_outputs.hidden_states[
                vision_feature_layer
            ]
            hf_image_features = self.hf_llava.multi_modal_projector(
                hf_elected_image_feature
            )
            hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids)
            hf_final_embedding, _, _, _ = (
                self.hf_llava._merge_input_ids_with_image_features(
                    hf_image_features,
                    hf_inputs_embeds,
                    input_ids,
                    attention_mask=input_ids,
                    labels=torch.ones_like(input_ids),
                )
            )

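            # Merged embeddings are compared with a looser tolerance than the raw image features.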
            self.assertTrue(
                mx.allclose(
                    mx_final_embedding,
                    mx.array(hf_final_embedding.numpy()),
                    atol=1e-1,
                )
            )

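    # Both implementations should predict the same next token for the prompt.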
    def test_generated_tokens(self):
        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
        with torch.no_grad():
            hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt")
            hf_outputs = self.hf_llava(**hf_inputs)
            hf_logits = hf_outputs.logits

        mx_inputs = self.proc(PROMPT, raw_image, return_tensors="np")
        pixel_values = mx.array(mx_inputs["pixel_values"])
        input_ids = mx.array(mx_inputs["input_ids"])

        mx_logits, _ = self.mx_llava(input_ids, pixel_values)

        self.assertTrue(
            mx.allclose(
                mx_logits[:, -1, :].argmax(axis=-1),
                mx.array(hf_logits.numpy())[:, -1, :].argmax(axis=-1),
                atol=1e-2,
            )
        )

if __name__ == "__main__":
|
|
unittest.main()
|