# Copyright © 2024 Apple Inc.
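"""LLaVA model for MLX: a vision tower, a multi-modal projector, and a
language model, combined so that projected image features replace the
<image> placeholder tokens in the prompt embedding sequence."""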

import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download

from language import LanguageModel, TextConfig
from vision import VisionConfig, VisionModel


@dataclass
class LlaVAConfig:
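    """Top-level configuration: bundles the text and vision sub-configurations
    together with the <image> token id and the vision feature selection
    settings."""
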
    text_config: TextConfig
    vision_config: VisionConfig
    ignore_index: int = -100
    image_token_index: int = 32000
    vision_feature_select_strategy: str = "default"
    vision_feature_layer: int = -2
    vocab_size: int = 32000

    @classmethod
    def from_dict(cls, params):
        # Keep only the keys that correspond to fields of this dataclass so
        # that extra entries in config.json are ignored.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class LlavaMultiModalProjector(nn.Module):
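    """Two-layer MLP with a GELU activation that maps vision features from the
    vision tower's hidden size to the language model's hidden size."""
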
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
        )
        self.gelu = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        return x


class LlavaModel(nn.Module):
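    """Wires the vision tower, the multi-modal projector, and the language
    model together for multimodal (image + text) generation."""
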
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config)
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vision_feature_layer = config.vision_feature_layer
        self.vision_feature_select_strategy = config.vision_feature_select_strategy

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
    ):
        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        if pixel_values is None:
            return inputs_embeds

        # Get the output hidden states from the vision model
        *_, hidden_states = self.vision_tower(
            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
        )

        # Select the hidden states from the desired layer
        selected_image_feature = hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            # Drop the CLS token and keep only the patch features
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                "Unexpected feature selection strategy: "
                f"{self.vision_feature_select_strategy}"
            )

        # Pass image features through the multi-modal projector
        image_features = self.multi_modal_projector(selected_image_feature)

        # Insert the image features at the <image> token positions
        final_inputs_embeds = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids
        )
        return final_inputs_embeds

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids
    ):
        image_token_index = self.config.image_token_index
        batch_size, num_image_patches, embed_dim = image_features.shape

        # Positions of <image> tokens in input_ids, assuming batch size is 1
        image_positions = mx.array(
            np.where(input_ids[0] == image_token_index)[0], mx.uint32
        )

        if len(image_positions) != num_image_patches:
            raise ValueError(
                f"The number of image tokens ({len(image_positions)}) does not "
                f"match the number of image patches ({num_image_patches})."
            )

        # Replace each <image> token embedding with the corresponding
        # projected image patch feature
        inputs_embeds[0, image_positions] = image_features
        return inputs_embeds

    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
        input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
        logits, cache = self.language_model(
            input_ids, cache=cache, inputs_embeds=input_embeddings
        )
        return logits, cache

    @staticmethod
    def from_pretrained(path_or_hf_repo: str):
        path = Path(path_or_hf_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                    ],
                )
            )

        with open(path / "config.json", "r") as f:
            model_config = json.load(f)

        model_config = LlaVAConfig.from_dict(model_config)

        model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
        model_config.text_config = TextConfig.from_dict(model_config.text_config)

        model = LlavaModel(model_config)
        weight_files = glob.glob(str(path / "*.safetensors"))
        if not weight_files:
            raise FileNotFoundError(f"No safetensors found in {path}")

        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf))

        weights = VisionModel.sanitize(weights)
        weights = LanguageModel.sanitize(weights)

        model.load_weights(list(weights.items()))
        return model
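

if __name__ == "__main__":
    # Minimal usage sketch (not from the original file): load a checkpoint and
    # run a single forward pass. The repo id, image path, and prompt below are
    # illustrative assumptions, and the sketch assumes a `transformers`
    # processor version that expands the "<image>" placeholder into one token
    # per image patch, which _merge_input_ids_with_image_features requires.
    from PIL import Image
    from transformers import AutoProcessor

    repo = "llava-hf/llava-1.5-7b-hf"  # hypothetical checkpoint choice
    processor = AutoProcessor.from_pretrained(repo)
    model = LlavaModel.from_pretrained(repo)

    prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
    image = Image.open("example.jpg")  # hypothetical local image
    inputs = processor(text=prompt, images=image, return_tensors="np")

    input_ids = mx.array(inputs["input_ids"])
    pixel_values = mx.array(inputs["pixel_values"])

    # One forward pass; logits has shape (1, sequence_length, vocab_size).
    logits, _ = model(input_ids, pixel_values)
    next_token = mx.argmax(logits[0, -1, :])
    print(next_token.item())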