mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
LlaVA in MLX (#461)
* add: llava mlx first draft * add: weights comparision * add forward pass skeleton * update: now imports weights correctly * delete base * latest * adding config * fix: use config * add mlx config * feat: add image processor for llava processor * wip * feat: llava working example * chore: refactor generate script * chore: clean up * add: warning to user if no <image> token despite using one * add: __call__ to LlavaModel * add: call to LlavaModel * update fp * clean up var names * update: native GeLU * Cleanup * update generate and readme * remove todo comment * rearrange tests * fix example code * nits in README * update readme * nit in readme * nits in README * chore(llava): refactor image embedding merging logic * min mlx version * nits in readmes * fix cli prompt, some nits * updates, slight simplify --------- Co-authored-by: anchen <li.anchen.au@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
261f1280f6
commit
a429263905
@ -31,6 +31,7 @@ Some more useful examples are listed below.
|
||||
### Multimodal models
|
||||
|
||||
- Joint text and image embeddings with [CLIP](clip).
|
||||
- Text generation from image and text inputs with [LLaVA](llava).
|
||||
|
||||
### Other Models
|
||||
|
||||
|
1
llava/.gitignore
vendored
Normal file
1
llava/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
**.ipynb
|
61
llava/README.md
Normal file
61
llava/README.md
Normal file
@ -0,0 +1,61 @@
|
||||
# LLaVA
|
||||
|
||||
An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLlava is
|
||||
a multimodal model that can generate text given combined image and text inputs.
|
||||
|
||||
## Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
You can use LLaVA to ask questions about images.
|
||||
|
||||
For example, using the command line:
|
||||
|
||||
```bash
|
||||
python generate.py \
|
||||
--model llava-hf/llava-1.5-7b-hf \
|
||||
--image "http://images.cocodataset.org/val2017/000000039769.jpg" \
|
||||
--prompt "USER: <image>\nWhat are these?\nASSISTANT:" \
|
||||
--max-tokens 128 \
|
||||
--temp 0
|
||||
```
|
||||
|
||||
This uses the following image:
|
||||
|
||||

|
||||
|
||||
And generates the output:
|
||||
|
||||
```
|
||||
These are two cats lying on a pink couch.
|
||||
```
|
||||
|
||||
You can also use LLaVA in Python:
|
||||
|
||||
```python
|
||||
from generate import load_model, prepare_inputs, generate_text
|
||||
|
||||
processor, model = load_model("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
max_tokens, temperature = 128, 0.0
|
||||
|
||||
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
|
||||
image = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
input_ids, pixel_values = prepare_inputs(processor, image, prompt)
|
||||
|
||||
reply = generate_text(
|
||||
input_ids, pixel_values, model, processor, max_tokens, temperature
|
||||
)
|
||||
|
||||
print(reply)
|
||||
```
|
||||
|
||||
[^1]:
|
||||
Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more
|
||||
information.
|
130
llava/generate.py
Normal file
130
llava/generate.py
Normal file
@ -0,0 +1,130 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import codecs
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from llava import LlavaModel
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate text from an image using a model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="llava-hf/llava-1.5-7b-hf",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=str,
|
||||
default="http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
help="URL or path of the image to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="USER: <image>\nWhat are these?\nASSISTANT:",
|
||||
help="Message to be processed by the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.3, help="Temperature for sampling."
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_image(image_source):
|
||||
"""
|
||||
Helper function to load an image from either a URL or file.
|
||||
"""
|
||||
if image_source.startswith(("http://", "https://")):
|
||||
try:
|
||||
response = requests.get(image_source, stream=True)
|
||||
response.raise_for_status()
|
||||
return Image.open(response.raw)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to load image from URL: {image_source} with error {e}"
|
||||
)
|
||||
elif Path(image_source).is_file():
|
||||
try:
|
||||
return Image.open(image_source)
|
||||
except IOError as e:
|
||||
raise ValueError(f"Failed to load image {image_source} with error: {e}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The image {image_source} must be a valid URL or existing file."
|
||||
)
|
||||
|
||||
|
||||
def prepare_inputs(processor, image, prompt):
|
||||
if isinstance(image, str):
|
||||
image = load_image(image)
|
||||
inputs = processor(prompt, image, return_tensors="np")
|
||||
pixel_values = mx.array(inputs["pixel_values"])
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
return input_ids, pixel_values
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
processor = AutoProcessor.from_pretrained(model_path)
|
||||
model = LlavaModel.from_pretrained(model_path)
|
||||
return processor, model
|
||||
|
||||
|
||||
def sample(logits, temperature=0.0):
|
||||
if temperature == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temperature))
|
||||
|
||||
|
||||
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
|
||||
|
||||
logits, cache = model(input_ids, pixel_values)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits, temperature=temperature)
|
||||
tokens = [y.item()]
|
||||
|
||||
for n in range(max_tokens - 1):
|
||||
logits, cache = model.language_model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits, temperature)
|
||||
token = y.item()
|
||||
if token == processor.tokenizer.eos_token_id:
|
||||
break
|
||||
tokens.append(token)
|
||||
|
||||
return processor.tokenizer.decode(tokens)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
processor, model = load_model(args.model)
|
||||
|
||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||
|
||||
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
|
||||
|
||||
print(prompt)
|
||||
generated_text = generate_text(
|
||||
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
|
||||
)
|
||||
print(generated_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
231
llava/language.py
Normal file
231
llava/language.py
Normal file
@ -0,0 +1,231 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextConfig:
|
||||
model_type: str
|
||||
hidden_size: int = 4096
|
||||
num_hidden_layers: int = 32
|
||||
intermediate_size: int = 11008
|
||||
num_attention_heads: int = 32
|
||||
rms_norm_eps: float = 1e-6
|
||||
vocab_size: int = 32000
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
|
||||
dim = config.hidden_size
|
||||
self.n_heads = n_heads = config.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
|
||||
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
|
||||
head_dim = config.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (
|
||||
1 / config.rope_scaling["factor"]
|
||||
if config.rope_scaling is not None
|
||||
and config.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
base=config.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if self.repeats > 1:
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Attention(config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
# for passing merged input embeddings
|
||||
if inputs_embeds is None:
|
||||
h = self.embed_tokens(inputs)
|
||||
else:
|
||||
h = inputs_embeds
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
if self.model_type != "llama":
|
||||
raise ValueError(
|
||||
f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
|
||||
)
|
||||
self.model = Llama(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
out, cache = self.model(inputs, cache, inputs_embeds)
|
||||
return self.lm_head(out), cache
|
||||
|
||||
@staticmethod
|
||||
def sanitize(weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
179
llava/llava.py
Normal file
179
llava/llava.py
Normal file
@ -0,0 +1,179 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from language import LanguageModel, TextConfig
|
||||
from vision import VisionConfig, VisionModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlaVAConfig:
|
||||
text_config: TextConfig
|
||||
vision_config: VisionConfig
|
||||
ignore_index: int = -100
|
||||
image_token_index: int = 32000
|
||||
vision_feature_select_strategy: str = "default"
|
||||
vision_feature_layer: int = -2
|
||||
vocab_size: int = 32000
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LlavaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlaVAConfig):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
|
||||
)
|
||||
self.gelu = nn.GELU()
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=True
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.linear_1(x)
|
||||
x = self.gelu(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class LlavaModel(nn.Module):
|
||||
def __init__(self, config: LlaVAConfig):
|
||||
self.config = config
|
||||
self.vision_tower = VisionModel(config.vision_config)
|
||||
self.language_model = LanguageModel(config.text_config)
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.vision_feature_layer = config.vision_feature_layer
|
||||
self.vision_feature_select_strategy = config.vision_feature_select_strategy
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Optional[mx.array] = None,
|
||||
pixel_values: Optional[mx.array] = None,
|
||||
):
|
||||
if pixel_values is None:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
# Get the input embeddings from the language model
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
# Get the ouptut hidden states from the vision model
|
||||
*_, hidden_states = self.vision_tower(
|
||||
pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
|
||||
)
|
||||
|
||||
# Select the hidden states from the desired layer
|
||||
selected_image_feature = hidden_states[self.vision_feature_layer]
|
||||
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unexpected feature selection strategy: "
|
||||
f"{self.vision_feature_select_strategy}"
|
||||
)
|
||||
|
||||
# Pass image features through the multi-modal projector
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# Insert special image tokens in the input_ids
|
||||
final_inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids
|
||||
)
|
||||
return final_inputs_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids
|
||||
):
|
||||
image_token_index = self.config.image_token_index
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
|
||||
# Positions of <image> tokens in input_ids, assuming batch size is 1
|
||||
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
|
||||
|
||||
if len(image_positions) != num_images:
|
||||
raise ValueError(
|
||||
f"The number of image tokens ({len(image_positions)}) does not "
|
||||
f" match the number of image inputs ({num_images})."
|
||||
)
|
||||
|
||||
text_segments = []
|
||||
start_idx = 0
|
||||
|
||||
for position in image_positions:
|
||||
text_segments.append(inputs_embeds[:, start_idx:position])
|
||||
start_idx = position + 1
|
||||
|
||||
image_embeddings = mx.split(image_features, image_features.shape[0])
|
||||
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
|
||||
final_embeddings += [inputs_embeds[:, start_idx:]]
|
||||
|
||||
# Create a final embedding of shape
|
||||
# (1, num_image_patches*num_images + sequence_len, embed_dim)
|
||||
return mx.concatenate(final_embeddings, axis=1)
|
||||
|
||||
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
|
||||
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
|
||||
logits, cache = self.language_model(
|
||||
input_ids, cache=cache, inputs_embeds=input_embddings
|
||||
)
|
||||
return logits, cache
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path_or_hf_repo: str):
|
||||
path = Path(path_or_hf_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
"*.safetensors",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
model_config = json.load(f)
|
||||
|
||||
model_config = LlaVAConfig.from_dict(model_config)
|
||||
|
||||
model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
|
||||
model_config.text_config = TextConfig.from_dict(model_config.text_config)
|
||||
|
||||
model = LlavaModel(model_config)
|
||||
weight_files = glob.glob(str(path / "*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors found in {path}")
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
weights = VisionModel.sanitize(weights)
|
||||
weights = LanguageModel.sanitize(weights)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
return model
|
6
llava/requirements.txt
Normal file
6
llava/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
mlx>=0.5.0
|
||||
numpy
|
||||
transformers
|
||||
torch
|
||||
huggingface_hub
|
||||
Pillow
|
162
llava/test.py
Normal file
162
llava/test.py
Normal file
@ -0,0 +1,162 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
||||
|
||||
from llava import LlavaModel
|
||||
|
||||
MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
|
||||
PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
|
||||
IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
|
||||
|
||||
def load_mlx_models(path):
|
||||
model = LlavaModel.from_pretrained(path)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_hf_models(path):
|
||||
model = LlavaForConditionalGeneration.from_pretrained(path)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class TestVisionTower(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mx_llava = load_mlx_models(MODEL_PATH)
|
||||
cls.hf_llava = load_hf_models(MODEL_PATH)
|
||||
cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
|
||||
|
||||
def test_image_features(self):
|
||||
raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
|
||||
vision_feature_layer = -2
|
||||
with torch.no_grad():
|
||||
pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
|
||||
"pixel_values"
|
||||
]
|
||||
|
||||
hf_pixel_values = pixel_values
|
||||
mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
|
||||
|
||||
_, _, hidden_states = self.mx_llava.vision_tower(
|
||||
mx_pixel_values,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
mx_elected_image_feature = hidden_states[vision_feature_layer]
|
||||
mx_image_features = self.mx_llava.multi_modal_projector(
|
||||
mx_elected_image_feature
|
||||
)
|
||||
|
||||
hf_image_outputs = self.hf_llava.vision_tower(
|
||||
hf_pixel_values, output_hidden_states=True
|
||||
)
|
||||
hf_elected_image_feature = hf_image_outputs.hidden_states[
|
||||
vision_feature_layer
|
||||
]
|
||||
hf_image_features = self.hf_llava.multi_modal_projector(
|
||||
hf_elected_image_feature
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx_image_features,
|
||||
mx.array(hf_image_features.numpy()),
|
||||
atol=1e-2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestLlava(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mx_llava = load_mlx_models(MODEL_PATH)
|
||||
cls.hf_llava = load_hf_models(MODEL_PATH)
|
||||
cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
|
||||
|
||||
def test_merge_input_ids_with_image_features(self):
|
||||
raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
|
||||
vision_feature_layer = -2
|
||||
with torch.no_grad():
|
||||
values = self.proc(PROMPT, raw_image, return_tensors="pt")
|
||||
pixel_values = values["pixel_values"]
|
||||
input_ids = values["input_ids"]
|
||||
|
||||
hf_pixel_values = pixel_values
|
||||
mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
|
||||
|
||||
_, _, hidden_states = self.mx_llava.vision_tower(
|
||||
mx_pixel_values,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
mx_input_ids = mx.array(input_ids.numpy())
|
||||
mx_elected_image_feature = hidden_states[vision_feature_layer]
|
||||
mx_image_features = self.mx_llava.multi_modal_projector(
|
||||
mx_elected_image_feature
|
||||
)
|
||||
mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens(
|
||||
mx_input_ids
|
||||
)
|
||||
mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features(
|
||||
mx_image_features, mx_inputs_embeds, mx_input_ids
|
||||
)
|
||||
|
||||
hf_image_outputs = self.hf_llava.vision_tower(
|
||||
hf_pixel_values, output_hidden_states=True
|
||||
)
|
||||
hf_elected_image_feature = hf_image_outputs.hidden_states[
|
||||
vision_feature_layer
|
||||
]
|
||||
hf_image_features = self.hf_llava.multi_modal_projector(
|
||||
hf_elected_image_feature
|
||||
)
|
||||
hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids)
|
||||
hf_final_embedding, _, _, _ = (
|
||||
self.hf_llava._merge_input_ids_with_image_features(
|
||||
hf_image_features,
|
||||
hf_inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask=input_ids,
|
||||
labels=torch.ones_like(input_ids),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx_final_embedding,
|
||||
mx.array(hf_final_embedding.numpy()),
|
||||
atol=1e-1,
|
||||
)
|
||||
)
|
||||
|
||||
def test_generated_tokens(self):
|
||||
raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
|
||||
with torch.no_grad():
|
||||
hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt")
|
||||
hf_outputs = self.hf_llava(**hf_inputs)
|
||||
hf_logits = hf_outputs.logits
|
||||
|
||||
mx_inputs = self.proc(PROMPT, raw_image, return_tensors="np")
|
||||
pixel_values = mx.array(mx_inputs["pixel_values"])
|
||||
input_ids = mx.array(mx_inputs["input_ids"])
|
||||
|
||||
mx_logits, _ = self.mx_llava(input_ids, pixel_values)
|
||||
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx_logits[:, -1, :].argmax(axis=-1),
|
||||
mx.array(hf_logits.numpy())[:, -1, :].argmax(axis=-1),
|
||||
atol=1e-2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
223
llava/vision.py
Normal file
223
llava/vision.py
Normal file
@ -0,0 +1,223 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionConfig:
|
||||
model_type: str
|
||||
num_hidden_layers: int = 24
|
||||
hidden_size: int = 1024
|
||||
intermediate_size: int = 4096
|
||||
num_attention_heads: int = 16
|
||||
image_size: int = 336
|
||||
patch_size: int = 14
|
||||
projection_dim: int = 768
|
||||
vocab_size: int = 32000
|
||||
num_channels: int = 3
|
||||
layer_norm_eps: float = 1e-5
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
num_heads: int,
|
||||
query_input_dims: Optional[int] = None,
|
||||
key_input_dims: Optional[int] = None,
|
||||
value_input_dims: Optional[int] = None,
|
||||
value_dims: Optional[int] = None,
|
||||
value_output_dims: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (dims % num_heads) != 0:
|
||||
raise ValueError(
|
||||
"The input feature dimensions should be divisible by the "
|
||||
f"number of heads ({dims} % {num_heads}) != 0"
|
||||
)
|
||||
|
||||
query_input_dims = query_input_dims or dims
|
||||
key_input_dims = key_input_dims or dims
|
||||
value_input_dims = value_input_dims or key_input_dims
|
||||
value_dims = value_dims or dims
|
||||
value_output_dims = value_output_dims or dims
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
|
||||
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
|
||||
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
|
||||
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.q_proj(queries)
|
||||
keys = self.k_proj(keys)
|
||||
values = self.v_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.activation_fn = nn.GELU(approx="fast")
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Attention(
|
||||
config.hidden_size, config.num_attention_heads, bias=True
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = MLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
|
||||
y = self.layer_norm1(x)
|
||||
y = self.self_attn(y, y, y, mask)
|
||||
x = x + y
|
||||
y = self.layer_norm2(x)
|
||||
y = self.mlp(y)
|
||||
return x + y
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
|
||||
|
||||
class VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = mx.zeros((config.hidden_size,))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
batch_size = x.shape[0]
|
||||
patch_embeddings = self.patch_embedding(x)
|
||||
patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
|
||||
embed_dim = patch_embeddings.shape[-1]
|
||||
cls_embeddings = mx.broadcast_to(
|
||||
self.class_embedding, (batch_size, 1, embed_dim)
|
||||
)
|
||||
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
|
||||
embeddings += self.position_embedding.weight
|
||||
return embeddings
|
||||
|
||||
|
||||
class ClipVisionModel(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.embeddings = VisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
|
||||
self.encoder = Encoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> mx.array:
|
||||
x = self.embeddings(x)
|
||||
x = self.pre_layrnorm(x)
|
||||
|
||||
encoder_states = (x,) if output_hidden_states else None
|
||||
|
||||
for l in self.encoder.layers:
|
||||
x = l(x, mask=None)
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (x,)
|
||||
|
||||
pooler_output = self.post_layernorm(x[:, 0, :])
|
||||
return pooler_output, x, encoder_states
|
||||
|
||||
|
||||
class VisionModel(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.model_type = config.model_type
|
||||
if self.model_type != "clip_vision_model":
|
||||
raise ValueError(f"Unsupported model type: {self.model_type}")
|
||||
|
||||
self.vision_model = ClipVisionModel(config)
|
||||
|
||||
def __call__(
|
||||
self, x: mx.array, output_hidden_states: Optional[bool] = None
|
||||
) -> mx.array:
|
||||
return self.vision_model(x, output_hidden_states)
|
||||
|
||||
@staticmethod
|
||||
def sanitize(weights):
|
||||
sanitized_weights = {}
|
||||
for k, v in weights.items():
|
||||
if "position_ids" in k:
|
||||
# Remove unused position_ids
|
||||
continue
|
||||
elif "patch_embedding.weight" in k:
|
||||
# PyTorch conv2d weight tensors have shape:
|
||||
# [out_channels, in_channels, kH, KW]
|
||||
# MLX conv2d expects the weight be of shape:
|
||||
# [out_channels, kH, KW, in_channels]
|
||||
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
|
||||
else:
|
||||
sanitized_weights[k] = v
|
||||
|
||||
return sanitized_weights
|
Loading…
Reference in New Issue
Block a user