LLaVA in MLX (#461)

* add: llava mlx first draft

* add: weights comparison

* add forward pass skeleton

* update: now imports weights correctly

* delete base

* latest

* adding config

* fix: use config

* add mlx config

* feat: add image processor for llava processor

* wip

* feat: llava working example

* chore: refactor generate script

* chore: clean up

* add: warn the user if the prompt has no <image> token despite an image being passed

* add: __call__ to LlavaModel

* add: call to LlavaModel

* update fp

* clean up var names

* update: native GeLU

* Cleanup

* update generate and readme

* remove todo comment

* rearrange tests

* fix example code

* nits in README

* update readme

* nit in readme

* nits in README

* chore(llava): refactor image embedding merging logic

* min mlx version

* nits in readmes

* fix cli prompt, some nits

* updates, slight simplify

---------

Co-authored-by: anchen <li.anchen.au@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Noah Kasmanoff 2024-03-01 13:28:35 -05:00 committed by GitHub
parent 261f1280f6
commit a429263905
9 changed files with 994 additions and 0 deletions

README.md (+1)

@@ -31,6 +31,7 @@ Some more useful examples are listed below.
### Multimodal models
- Joint text and image embeddings with [CLIP](clip).
- Text generation from image and text inputs with [LLaVA](llava).
### Other Models

llava/.gitignore vendored Normal file (+1)

@@ -0,0 +1 @@
**.ipynb

llava/README.md Normal file (+61)

@@ -0,0 +1,61 @@
# LLaVA
An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLaVA is
a multimodal model that can generate text given combined image and text inputs.
## Setup
Install the dependencies:
```bash
pip install -r requirements.txt
```
## Run
You can use LLaVA to ask questions about images.
For example, using the command line:
```bash
python generate.py \
--model llava-hf/llava-1.5-7b-hf \
--image "http://images.cocodataset.org/val2017/000000039769.jpg" \
--prompt "USER: <image>\nWhat are these?\nASSISTANT:" \
--max-tokens 128 \
--temp 0
```
This uses the following image:
![alt text](http://images.cocodataset.org/val2017/000000039769.jpg)
And generates the output:
```
These are two cats lying on a pink couch.
```
You can also use LLaVA in Python:
```python
from generate import load_model, prepare_inputs, generate_text
processor, model = load_model("llava-hf/llava-1.5-7b-hf")
max_tokens, temperature = 128, 0.0
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image = "http://images.cocodataset.org/val2017/000000039769.jpg"
input_ids, pixel_values = prepare_inputs(processor, image, prompt)
reply = generate_text(
input_ids, pixel_values, model, processor, max_tokens, temperature
)
print(reply)
```
[^1]:
    Refer to the [LLaVA project webpage](https://llava-vl.github.io/) for
    more information.

llava/generate.py Normal file (+130)

@@ -0,0 +1,130 @@
# Copyright © 2024 Apple Inc.
import argparse
import codecs
from pathlib import Path
import mlx.core as mx
import requests
from PIL import Image
from transformers import AutoProcessor
from llava import LlavaModel
def parse_arguments():
parser = argparse.ArgumentParser(
description="Generate text from an image using a model."
)
parser.add_argument(
"--model",
type=str,
default="llava-hf/llava-1.5-7b-hf",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--image",
type=str,
default="http://images.cocodataset.org/val2017/000000039769.jpg",
help="URL or path of the image to process.",
)
parser.add_argument(
"--prompt",
type=str,
default="USER: <image>\nWhat are these?\nASSISTANT:",
help="Message to be processed by the model.",
)
parser.add_argument(
"--max-tokens",
type=int,
default=100,
help="Maximum number of tokens to generate.",
)
parser.add_argument(
"--temp", type=float, default=0.3, help="Temperature for sampling."
)
return parser.parse_args()
def load_image(image_source):
"""
Helper function to load an image from either a URL or file.
"""
if image_source.startswith(("http://", "https://")):
try:
response = requests.get(image_source, stream=True)
response.raise_for_status()
return Image.open(response.raw)
except Exception as e:
raise ValueError(
f"Failed to load image from URL: {image_source} with error {e}"
)
elif Path(image_source).is_file():
try:
return Image.open(image_source)
except IOError as e:
raise ValueError(f"Failed to load image {image_source} with error: {e}")
else:
raise ValueError(
f"The image {image_source} must be a valid URL or existing file."
)
def prepare_inputs(processor, image, prompt):
if isinstance(image, str):
image = load_image(image)
inputs = processor(prompt, image, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
return input_ids, pixel_values
def load_model(model_path):
processor = AutoProcessor.from_pretrained(model_path)
model = LlavaModel.from_pretrained(model_path)
return processor, model
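# A temperature of 0 picks the argmax token (greedy decoding); otherwise we
# sample from the categorical distribution over logits scaled by 1/temperature.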
def sample(logits, temperature=0.0):
if temperature == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temperature))
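# Prefill with the full prompt (text + image features) to build the KV cache,
# then decode autoregressively through the language model one token at a time.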
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
logits, cache = model(input_ids, pixel_values)
logits = logits[:, -1, :]
y = sample(logits, temperature=temperature)
tokens = [y.item()]
    for _ in range(max_tokens - 1):
logits, cache = model.language_model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits, temperature)
token = y.item()
if token == processor.tokenizer.eos_token_id:
break
tokens.append(token)
return processor.tokenizer.decode(tokens)
def main():
args = parse_arguments()
processor, model = load_model(args.model)
prompt = codecs.decode(args.prompt, "unicode_escape")
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
print(prompt)
generated_text = generate_text(
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
)
print(generated_text)
if __name__ == "__main__":
main()

llava/language.py Normal file (+231)

@@ -0,0 +1,231 @@
# Copyright © 2024 Apple Inc.
import inspect
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@dataclass
class TextConfig:
model_type: str
hidden_size: int = 4096
num_hidden_layers: int = 32
intermediate_size: int = 11008
num_attention_heads: int = 32
rms_norm_eps: float = 1e-6
vocab_size: int = 32000
    num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
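# Llama-style attention with rotary position embeddings (RoPE) and support for
# grouped-query attention: key/value heads are repeated to match query heads.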
class Attention(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
dim = config.hidden_size
self.n_heads = n_heads = config.num_attention_heads
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = config.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / config.rope_scaling["factor"]
if config.rope_scaling is not None
and config.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=config.rope_traditional,
base=config.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if self.repeats > 1:
keys = mx.repeat(keys, self.repeats, axis=1)
values = mx.repeat(values, self.repeats, axis=1)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
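# SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x)).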
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.mlp = MLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.config = config
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
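# Decoder-only Llama backbone. It accepts either token ids or precomputed
# (merged text/image) embeddings via `inputs_embeds`.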
class Llama(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
):
# for passing merged input embeddings
if inputs_embeds is None:
h = self.embed_tokens(inputs)
else:
h = inputs_embeds
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class LanguageModel(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
self.model_type = config.model_type
if self.model_type != "llama":
raise ValueError(
f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
)
self.model = Llama(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
):
out, cache = self.model(inputs, cache, inputs_embeds)
return self.lm_head(out), cache
@staticmethod
def sanitize(weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}

llava/llava.py Normal file (+179)

@@ -0,0 +1,179 @@
# Copyright © 2024 Apple Inc.
import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
from language import LanguageModel, TextConfig
from vision import VisionConfig, VisionModel
@dataclass
class LlaVAConfig:
text_config: TextConfig
vision_config: VisionConfig
ignore_index: int = -100
image_token_index: int = 32000
vision_feature_select_strategy: str = "default"
vision_feature_layer: int = -2
vocab_size: int = 32000
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
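# Two-layer MLP that projects CLIP image features from the vision hidden size
# into the language model's embedding space, as in LLaVA-1.5.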
class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlaVAConfig):
super().__init__()
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
)
self.gelu = nn.GELU()
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=True
)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear_1(x)
x = self.gelu(x)
x = self.linear_2(x)
return x
class LlavaModel(nn.Module):
    def __init__(self, config: LlaVAConfig):
        super().__init__()
        self.config = config
self.vision_tower = VisionModel(config.vision_config)
self.language_model = LanguageModel(config.text_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
        if pixel_values is None:
            return self.language_model.model.embed_tokens(input_ids)
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
        # Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(
pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
)
# Select the hidden states from the desired layer
selected_image_feature = hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
"Unexpected feature selection strategy: "
f"{self.vision_feature_select_strategy}"
)
# Pass image features through the multi-modal projector
image_features = self.multi_modal_projector(selected_image_feature)
        # Insert the projected image features at the <image> token positions
final_inputs_embeds = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids
)
return final_inputs_embeds
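    # Split the text embeddings at each <image> token and splice the projected
    # image patch embeddings in between, yielding a single merged sequence.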
def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape
# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
if len(image_positions) != num_images:
raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images})."
)
text_segments = []
start_idx = 0
for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
        input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
        logits, cache = self.language_model(
            input_ids, cache=cache, inputs_embeds=input_embeddings
        )
return logits, cache
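    # Load the config and safetensors weights from a local directory, or
    # download a snapshot of the Hugging Face repo first.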
@staticmethod
def from_pretrained(path_or_hf_repo: str):
path = Path(path_or_hf_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
],
)
)
with open(path / "config.json", "r") as f:
model_config = json.load(f)
model_config = LlaVAConfig.from_dict(model_config)
model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
model_config.text_config = TextConfig.from_dict(model_config.text_config)
model = LlavaModel(model_config)
weight_files = glob.glob(str(path / "*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors found in {path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
weights = VisionModel.sanitize(weights)
weights = LanguageModel.sanitize(weights)
model.load_weights(list(weights.items()))
return model

llava/requirements.txt Normal file (+6)

@@ -0,0 +1,6 @@
mlx>=0.5.0
numpy
transformers
torch
huggingface_hub
Pillow

llava/test.py Normal file (+162)

@@ -0,0 +1,162 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from llava import LlavaModel
MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
def load_mlx_models(path):
model = LlavaModel.from_pretrained(path)
model.eval()
return model
def load_hf_models(path):
model = LlavaForConditionalGeneration.from_pretrained(path)
model.eval()
return model
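# Compare the MLX vision tower and multi-modal projector against the reference
# Hugging Face implementation on the same inputs.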
class TestVisionTower(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mx_llava = load_mlx_models(MODEL_PATH)
cls.hf_llava = load_hf_models(MODEL_PATH)
cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
def test_image_features(self):
raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
vision_feature_layer = -2
with torch.no_grad():
pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
"pixel_values"
]
hf_pixel_values = pixel_values
mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
_, _, hidden_states = self.mx_llava.vision_tower(
mx_pixel_values,
output_hidden_states=True,
)
            mx_selected_image_feature = hidden_states[vision_feature_layer]
            mx_image_features = self.mx_llava.multi_modal_projector(
                mx_selected_image_feature
            )
hf_image_outputs = self.hf_llava.vision_tower(
hf_pixel_values, output_hidden_states=True
)
            hf_selected_image_feature = hf_image_outputs.hidden_states[
                vision_feature_layer
            ]
            hf_image_features = self.hf_llava.multi_modal_projector(
                hf_selected_image_feature
            )
self.assertTrue(
mx.allclose(
mx_image_features,
mx.array(hf_image_features.numpy()),
atol=1e-2,
)
)
class TestLlava(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mx_llava = load_mlx_models(MODEL_PATH)
cls.hf_llava = load_hf_models(MODEL_PATH)
cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
def test_merge_input_ids_with_image_features(self):
raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
vision_feature_layer = -2
with torch.no_grad():
values = self.proc(PROMPT, raw_image, return_tensors="pt")
pixel_values = values["pixel_values"]
input_ids = values["input_ids"]
hf_pixel_values = pixel_values
mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
_, _, hidden_states = self.mx_llava.vision_tower(
mx_pixel_values,
output_hidden_states=True,
)
mx_input_ids = mx.array(input_ids.numpy())
            mx_selected_image_feature = hidden_states[vision_feature_layer]
            mx_image_features = self.mx_llava.multi_modal_projector(
                mx_selected_image_feature
            )
mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens(
mx_input_ids
)
mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features(
mx_image_features, mx_inputs_embeds, mx_input_ids
)
hf_image_outputs = self.hf_llava.vision_tower(
hf_pixel_values, output_hidden_states=True
)
            hf_selected_image_feature = hf_image_outputs.hidden_states[
                vision_feature_layer
            ]
            hf_image_features = self.hf_llava.multi_modal_projector(
                hf_selected_image_feature
            )
hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids)
hf_final_embedding, _, _, _ = (
self.hf_llava._merge_input_ids_with_image_features(
hf_image_features,
hf_inputs_embeds,
input_ids,
attention_mask=input_ids,
labels=torch.ones_like(input_ids),
)
)
self.assertTrue(
mx.allclose(
mx_final_embedding,
mx.array(hf_final_embedding.numpy()),
atol=1e-1,
)
)
def test_generated_tokens(self):
raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
with torch.no_grad():
hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt")
hf_outputs = self.hf_llava(**hf_inputs)
hf_logits = hf_outputs.logits
mx_inputs = self.proc(PROMPT, raw_image, return_tensors="np")
pixel_values = mx.array(mx_inputs["pixel_values"])
input_ids = mx.array(mx_inputs["input_ids"])
mx_logits, _ = self.mx_llava(input_ids, pixel_values)
self.assertTrue(
mx.allclose(
mx_logits[:, -1, :].argmax(axis=-1),
mx.array(hf_logits.numpy())[:, -1, :].argmax(axis=-1),
atol=1e-2,
)
)
if __name__ == "__main__":
unittest.main()

llava/vision.py Normal file (+223)

@@ -0,0 +1,223 @@
# Copyright © 2024 Apple Inc.
import inspect
import math
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
@dataclass
class VisionConfig:
model_type: str
num_hidden_layers: int = 24
hidden_size: int = 1024
intermediate_size: int = 4096
num_attention_heads: int = 16
image_size: int = 336
patch_size: int = 14
projection_dim: int = 768
vocab_size: int = 32000
num_channels: int = 3
layer_norm_eps: float = 1e-5
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class Attention(nn.Module):
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
"The input feature dimensions should be divisible by the "
f"number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
def __call__(self, queries, keys, values, mask=None):
queries = self.q_proj(queries)
keys = self.k_proj(keys)
values = self.v_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
class MLP(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.activation_fn = nn.GELU(approx="fast")
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def __call__(self, x: mx.array) -> mx.array:
x = self.activation_fn(self.fc1(x))
x = self.fc2(x)
return x
class EncoderLayer(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Attention(
config.hidden_size, config.num_attention_heads, bias=True
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = MLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
y = self.layer_norm1(x)
y = self.self_attn(y, y, y, mask)
x = x + y
y = self.layer_norm2(x)
y = self.mlp(y)
return x + y
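# Plain container for the encoder layers; ClipVisionModel iterates over them
# directly so it can also collect intermediate hidden states.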
class Encoder(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
class VisionEmbeddings(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = mx.zeros((config.hidden_size,))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def __call__(self, x: mx.array) -> mx.array:
batch_size = x.shape[0]
patch_embeddings = self.patch_embedding(x)
patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
embed_dim = patch_embeddings.shape[-1]
cls_embeddings = mx.broadcast_to(
self.class_embedding, (batch_size, 1, embed_dim)
)
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
embeddings += self.position_embedding.weight
return embeddings
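# Attribute names (including the `pre_layrnorm` spelling) match the Hugging
# Face CLIP checkpoint so the pretrained weights load without renaming.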
class ClipVisionModel(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.embeddings = VisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
self.encoder = Encoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size)
def __call__(
self,
x: mx.array,
output_hidden_states: Optional[bool] = None,
) -> mx.array:
x = self.embeddings(x)
x = self.pre_layrnorm(x)
encoder_states = (x,) if output_hidden_states else None
        for layer in self.encoder.layers:
            x = layer(x, mask=None)
if output_hidden_states:
encoder_states = encoder_states + (x,)
pooler_output = self.post_layernorm(x[:, 0, :])
return pooler_output, x, encoder_states
class VisionModel(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.model_type = config.model_type
if self.model_type != "clip_vision_model":
raise ValueError(f"Unsupported model type: {self.model_type}")
self.vision_model = ClipVisionModel(config)
def __call__(
self, x: mx.array, output_hidden_states: Optional[bool] = None
) -> mx.array:
return self.vision_model(x, output_hidden_states)
@staticmethod
def sanitize(weights):
sanitized_weights = {}
for k, v in weights.items():
if "position_ids" in k:
# Remove unused position_ids
continue
elif "patch_embedding.weight" in k:
                # PyTorch conv2d weight tensors have shape:
                #   [out_channels, in_channels, kH, kW]
                # MLX conv2d expects the weight to be of shape:
                #   [out_channels, kH, kW, in_channels]
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
else:
sanitized_weights[k] = v
return sanitized_weights