LlaVA in MLX (#461)
* add: llava mlx first draft
* add: weights comparison
* add forward pass skeleton
* update: now imports weights correctly
* delete base
* latest
* adding config
* fix: use config
* add mlx config
* feat: add image processor for llava processor
* wip
* feat: llava working example
* chore: refactor generate script
* chore: clean up
* add: warning to user if no <image> token despite using one
* add: __call__ to LlavaModel
* add: call to LlavaModel
* update fp
* clean up var names
* update: native GeLU
* Cleanup
* update generate and readme
* remove todo comment
* rearrange tests
* fix example code
* nits in README
* update readme
* nit in readme
* nits in README
* chore(llava): refactor image embedding merging logic
* min mlx version
* nits in readmes
* fix cli prompt, some nits
* updates, slight simplify

---------

Co-authored-by: anchen <li.anchen.au@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
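The new test file below checks the MLX LLaVA port against the Hugging Face reference implementation at three points: the vision tower's projected image features, the merging of image features into the token embeddings, and the next-token prediction from the full forward pass.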
llava/test.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

from llava import LlavaModel

MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"

def load_mlx_models(path):
    model = LlavaModel.from_pretrained(path)
    model.eval()
    return model


def load_hf_models(path):
    model = LlavaForConditionalGeneration.from_pretrained(path)
    model.eval()
    return model

class TestVisionTower(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.mx_llava = load_mlx_models(MODEL_PATH)
        cls.hf_llava = load_hf_models(MODEL_PATH)
        cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)

    def test_image_features(self):
        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
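        # LLaVA 1.5 takes image features from the penultimate hidden state of
        # the vision tower, hence index -2.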
        vision_feature_layer = -2
        with torch.no_grad():
            pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
                "pixel_values"
            ]

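            # The HF processor returns NCHW torch tensors; MLX convolutions
            # expect NHWC, hence the transpose for the MLX vision tower.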
            hf_pixel_values = pixel_values
            mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)

            _, _, hidden_states = self.mx_llava.vision_tower(
                mx_pixel_values,
                output_hidden_states=True,
            )

            mx_elected_image_feature = hidden_states[vision_feature_layer]
            mx_image_features = self.mx_llava.multi_modal_projector(
                mx_elected_image_feature
            )

            hf_image_outputs = self.hf_llava.vision_tower(
                hf_pixel_values, output_hidden_states=True
            )
            hf_elected_image_feature = hf_image_outputs.hidden_states[
                vision_feature_layer
            ]
            hf_image_features = self.hf_llava.multi_modal_projector(
                hf_elected_image_feature
            )

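            # The projected image features should match the reference closely;
            # a loose tolerance absorbs small cross-framework float differences.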
            self.assertTrue(
                mx.allclose(
                    mx_image_features,
                    mx.array(hf_image_features.numpy()),
                    atol=1e-2,
                )
            )


class TestLlava(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.mx_llava = load_mlx_models(MODEL_PATH)
        cls.hf_llava = load_hf_models(MODEL_PATH)
        cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)

    def test_merge_input_ids_with_image_features(self):
        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
        vision_feature_layer = -2
        with torch.no_grad():
            values = self.proc(PROMPT, raw_image, return_tensors="pt")
            pixel_values = values["pixel_values"]
            input_ids = values["input_ids"]

            hf_pixel_values = pixel_values
            mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)

            _, _, hidden_states = self.mx_llava.vision_tower(
                mx_pixel_values,
                output_hidden_states=True,
            )
            mx_input_ids = mx.array(input_ids.numpy())
            mx_elected_image_feature = hidden_states[vision_feature_layer]
            mx_image_features = self.mx_llava.multi_modal_projector(
                mx_elected_image_feature
            )
            mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens(
                mx_input_ids
            )
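            # Splice the projected image features into the token embedding
            # sequence at the <image> placeholder position.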
            mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features(
                mx_image_features, mx_inputs_embeds, mx_input_ids
            )

            hf_image_outputs = self.hf_llava.vision_tower(
                hf_pixel_values, output_hidden_states=True
            )
            hf_elected_image_feature = hf_image_outputs.hidden_states[
                vision_feature_layer
            ]
            hf_image_features = self.hf_llava.multi_modal_projector(
                hf_elected_image_feature
            )
            hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids)
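            # The HF merge also returns an attention mask, labels, and position
            # ids; only the merged embeddings are compared here.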
            hf_final_embedding, _, _, _ = (
                self.hf_llava._merge_input_ids_with_image_features(
                    hf_image_features,
                    hf_inputs_embeds,
                    input_ids,
                    attention_mask=input_ids,
                    labels=torch.ones_like(input_ids),
                )
            )

            self.assertTrue(
                mx.allclose(
                    mx_final_embedding,
                    mx.array(hf_final_embedding.numpy()),
                    atol=1e-1,
                )
            )

    def test_generated_tokens(self):
        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
        with torch.no_grad():
            hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt")
            hf_outputs = self.hf_llava(**hf_inputs)
            hf_logits = hf_outputs.logits

            mx_inputs = self.proc(PROMPT, raw_image, return_tensors="np")
            pixel_values = mx.array(mx_inputs["pixel_values"])
            input_ids = mx.array(mx_inputs["input_ids"])

            mx_logits, _ = self.mx_llava(input_ids, pixel_values)

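            # Raw logits drift slightly across frameworks, so compare only the
            # argmax at the final position, i.e. the next predicted token.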
            self.assertTrue(
                mx.allclose(
                    mx_logits[:, -1, :].argmax(axis=-1),
                    mx.array(hf_logits.numpy())[:, -1, :].argmax(axis=-1),
                    atol=1e-2,
                )
            )


if __name__ == "__main__":
    unittest.main()
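To run the tests, execute "python test.py" from the llava/ directory (the file imports LlavaModel from the local llava module). Note that both the MLX and PyTorch models load the full llava-hf/llava-1.5-7b-hf checkpoint, so expect a multi-gigabyte download on the first run.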