# Copyright © 2024 Apple Inc.

import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download

from language import LanguageModel, TextConfig
from vision import VisionConfig, VisionModel


@dataclass
class LlaVAConfig:
    text_config: TextConfig
    vision_config: VisionConfig
    ignore_index: int = -100
    image_token_index: int = 32000
    vision_feature_select_strategy: str = "default"
    vision_feature_layer: int = -2
    vocab_size: int = 32000

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
        )
        self.gelu = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        return x


class LlavaModel(nn.Module):
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config)
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vision_feature_layer = config.vision_feature_layer
        self.vision_feature_select_strategy = config.vision_feature_select_strategy

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
    ):
        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        if pixel_values is None:
            return inputs_embeds

        # Get the output hidden states from the vision model; the vision
        # tower expects NHWC, so transpose from the processor's NCHW layout
        *_, hidden_states = self.vision_tower(
            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
        )

        # Select the hidden states from the desired layer
        selected_image_feature = hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            # Drop the CLS token, keeping only the patch embeddings
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                "Unexpected feature selection strategy: "
                f"{self.vision_feature_select_strategy}"
            )

        # Pass image features through the multi-modal projector
        image_features = self.multi_modal_projector(selected_image_feature)

        # Insert the image features at the image token positions
        final_inputs_embeds = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids
        )
        return final_inputs_embeds

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids
    ):
        image_token_index = self.config.image_token_index
        batch_size, num_image_patches, embed_dim = image_features.shape

        # Positions of <image> tokens in input_ids, assuming batch size is 1
        image_positions = mx.array(
            np.where(input_ids[0] == image_token_index)[0], mx.uint32
        )

        if len(image_positions) != num_image_patches:
            raise ValueError(
                f"The number of image tokens ({len(image_positions)}) does not "
                f"match the number of image patches ({num_image_patches})."
            )

        # Replace the image token embeddings with the projected image features
        inputs_embeds[0, image_positions] = image_features
        return inputs_embeds

    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
        input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
        logits, cache = self.language_model(
            input_ids, cache=cache, inputs_embeds=input_embeddings
        )
        return logits, cache

    @staticmethod
    def from_pretrained(path_or_hf_repo: str):
        path = Path(path_or_hf_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                    ],
                )
            )

        with open(path / "config.json", "r") as f:
            model_config = json.load(f)

        model_config = LlaVAConfig.from_dict(model_config)
        model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
        model_config.text_config = TextConfig.from_dict(model_config.text_config)

        model = LlavaModel(model_config)
        weight_files = glob.glob(str(path / "*.safetensors"))
        if not weight_files:
            raise FileNotFoundError(f"No safetensors found in {path}")

        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf))

        # Remap checkpoint keys to this module's parameter names
        weights = VisionModel.sanitize(weights)
        weights = LanguageModel.sanitize(weights)

        model.load_weights(list(weights.items()))
        return model
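

# A minimal usage sketch (illustrative; not part of the module above). It
# assumes the `transformers` AutoProcessor for a Hugging Face LLaVA
# checkpoint; the repo id, prompt, and "image.png" path are placeholders.
# Note that `_merge_input_ids_with_image_features` requires one image token
# (index 32000) in `input_ids` per vision patch, so the processor must
# expand "<image>" accordingly.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = AutoProcessor.from_pretrained(model_id)
    model = LlavaModel.from_pretrained(model_id)

    prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
    inputs = processor(text=prompt, images=Image.open("image.png"), return_tensors="np")

    logits, cache = model(
        mx.array(inputs["input_ids"]), mx.array(inputs["pixel_values"])
    )
    # Greedily pick the next token from the last position's logits
    next_token = mx.argmax(logits[:, -1, :], axis=-1)
    print(next_token)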