LlaVA in MLX (#461)
* add: llava mlx first draft
* add: weights comparison
* add forward pass skeleton
* update: now imports weights correctly
* delete base
* latest
* adding config
* fix: use config
* add mlx config
* feat: add image processor for llava processor
* wip
* feat: llava working example
* chore: refactor generate script
* chore: clean up
* add: warning to user if no <image> token despite using one
* add: __call__ to LlavaModel
* add: call to LlavaModel
* update fp
* clean up var names
* update: native GeLU
* Cleanup
* update generate and readme
* remove todo comment
* rearrange tests
* fix example code
* nits in README
* update readme
* nit in readme
* nits in README
* chore(llava): refactor image embedding merging logic
* min mlx version
* nits in readmes
* fix cli prompt, some nits

---------

Co-authored-by: anchen <li.anchen.au@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
llava/llava.py (new file):
# Copyright © 2024 Apple Inc.

import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
from language import LanguageModel, TextConfig
from vision import VisionConfig, VisionModel


@dataclass
class LlaVAConfig:
    text_config: TextConfig
    vision_config: VisionConfig
    ignore_index: int = -100
    image_token_index: int = 32000
    vision_feature_select_strategy: str = "default"
    vision_feature_layer: int = -2
    vocab_size: int = 32000

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
        )
        self.gelu = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        return x


class LlavaModel(nn.Module):
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config)
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vision_feature_layer = config.vision_feature_layer
        self.vision_feature_select_strategy = config.vision_feature_select_strategy

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
    ):
        if pixel_values is None:
            return self.language_model(input_ids)

        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # Get the output hidden states from the vision model
        *_, hidden_states = self.vision_tower(
            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
        )

        # Select the hidden states from the desired layer
        selected_image_feature = hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                "Unexpected feature selection strategy: "
                f"{self.vision_feature_select_strategy}"
            )

        # Pass image features through the multi-modal projector
        image_features = self.multi_modal_projector(selected_image_feature)

        # Insert special image tokens in the input_ids
        final_inputs_embeds = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids
        )
        return final_inputs_embeds

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids
    ):
        image_token_index = self.config.image_token_index
        num_images, num_image_patches, embed_dim = image_features.shape

        # Positions of <image> tokens in input_ids, assuming batch size is 1
        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
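        # Illustration (shapes are an example, not checked here): with one image
        # and a prompt like "USER: <image>\nWhat is this?", inputs_embeds has
        # shape (1, sequence_len, embed_dim) and image_features has shape
        # (1, num_image_patches, embed_dim); for LLaVA-1.5's 336x336 CLIP tower
        # that is 576 patches. Below, each <image> token is replaced by the full
        # block of patch embeddings for the corresponding image.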

        if len(image_positions) != num_images:
            raise ValueError(
                f"The number of image tokens ({len(image_positions)}) does not "
                f"match the number of image inputs ({num_images})."
            )

        text_segments = []
        start_idx = 0

        for position in image_positions:
            text_segments.append(inputs_embeds[:, start_idx:position])
            start_idx = position + 1

        image_embeddings = mx.split(image_features, image_features.shape[0])
        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
        final_embeddings += [inputs_embeds[:, start_idx:]]

        # Create a final embedding of shape
        # (1, num_images * num_image_patches + sequence_len - num_images, embed_dim)
        return mx.concatenate(final_embeddings, axis=1)

    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
        input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
        logits, cache = self.language_model(
            input_ids, cache=cache, inputs_embeds=input_embeddings
        )
        return logits, cache

    @staticmethod
    def from_pretrained(path_or_hf_repo: str):
        path = Path(path_or_hf_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                    ],
                )
            )

        with open(path / "config.json", "r") as f:
            model_config = json.load(f)

        model_config = LlaVAConfig.from_dict(model_config)

        model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
        model_config.text_config = TextConfig.from_dict(model_config.text_config)

        model = LlavaModel(model_config)
        weight_files = glob.glob(str(path / "*.safetensors"))
        if not weight_files:
            raise FileNotFoundError(f"No safetensors found in {path}")

        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf))

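        # Let each sub-model adapt the raw Hugging Face checkpoint (key names
        # and layouts) to the MLX modules defined in vision.py and language.py.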
        weights = VisionModel.sanitize(weights)
        weights = LanguageModel.sanitize(weights)

        model.load_weights(list(weights.items()))
        return model
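For reference, here is a minimal usage sketch (not part of llava.py). It assumes the example's companion language.py and vision.py modules are importable alongside this file, uses Hugging Face transformers' AutoProcessor for tokenization and image preprocessing, and treats the model id, prompt, and image path as placeholders; see the example's generate.py and README for the supported workflow.

# Minimal usage sketch; model id, prompt, and image path are illustrative.
from PIL import Image
from transformers import AutoProcessor

import mlx.core as mx

from llava import LlavaModel

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaModel.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "USER: <image>\nWhat is shown in this image?\nASSISTANT:"
inputs = processor(text=prompt, images=Image.open("example.jpg"), return_tensors="np")
input_ids = mx.array(inputs["input_ids"])
pixel_values = mx.array(inputs["pixel_values"])

# One forward pass over the merged text+image embeddings, then a greedy pick
# of the next token from the last position.
logits, cache = model(input_ids, pixel_values)
next_token = mx.argmax(logits[:, -1, :], axis=-1)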