# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from typing import List, Optional

import mlx.core as mx
import mlx.nn as nn
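
# Map Hugging Face style activation names to MLX activation functions; in MLX,
# nn.gelu_fast_approx is the sigmoid ("quick GELU") approximation that CLIP
# checkpoints refer to as "quick_gelu".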
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}


@dataclass
class CLIPTextModelConfig:
    num_layers: int = 23
    model_dims: int = 1024
    num_heads: int = 16
    max_length: int = 77
    vocab_size: int = 49408
    hidden_act: str = "quick_gelu"
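
    # Illustrative only: from_dict below consumes a Hugging Face style CLIP
    # text-encoder config dict, e.g. (values mirror the defaults above)
    #   CLIPTextModelConfig.from_dict({
    #       "num_hidden_layers": 23, "hidden_size": 1024,
    #       "num_attention_heads": 16, "max_position_embeddings": 77,
    #       "vocab_size": 49408, "hidden_act": "quick_gelu",
    #   })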
    @classmethod
    def from_dict(cls, config):
        return cls(
            num_layers=config["num_hidden_layers"],
            model_dims=config["hidden_size"],
            num_heads=config["num_attention_heads"],
            max_length=config["max_position_embeddings"],
            vocab_size=config["vocab_size"],
            hidden_act=config["hidden_act"],
        )


@dataclass
class CLIPOutput:
    # The last_hidden_state indexed at the EOS token and possibly projected if
    # the model has a projection layer
    pooled_output: Optional[mx.array] = None

    # The full sequence output of the transformer after the final layernorm
    last_hidden_state: Optional[mx.array] = None

    # A list of hidden states corresponding to the outputs of the transformer layers
    hidden_states: Optional[List[mx.array]] = None


class CLIPEncoderLayer(nn.Module):
    """The transformer encoder layer from CLIP."""

    def __init__(self, model_dims: int, num_heads: int, activation: str):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(model_dims)
        self.layer_norm2 = nn.LayerNorm(model_dims)

        self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)

        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
        self.linear2 = nn.Linear(4 * model_dims, model_dims)

        self.act = _ACTIVATIONS[activation]

    def __call__(self, x, attn_mask=None):
        y = self.layer_norm1(x)
        y = self.attention(y, y, y, attn_mask)
        x = y + x

        y = self.layer_norm2(x)
        y = self.linear1(y)
        y = self.act(y)
        y = self.linear2(y)
        x = y + x

        return x
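

# Illustrative sketch (comments only, not executed on import): each layer is a
# pre-LayerNorm residual block, so it preserves the sequence shape, mapping
# (B, N, model_dims) to (B, N, model_dims), e.g.
#
#   layer = CLIPEncoderLayer(model_dims=64, num_heads=4, activation="gelu")
#   y = layer(mx.random.normal((1, 77, 64)))  # y.shape == (1, 77, 64)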


class CLIPTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextModelConfig):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
        self.layers = [
            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
            for i in range(config.num_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(config.model_dims)
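
    # _get_mask builds an additive causal mask: each position can attend only
    # to itself and earlier tokens. Disallowed entries get a large negative
    # value (-6e4 in float16 to stay within half-precision range, -1e9
    # otherwise) so they vanish after the attention softmax.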
    def _get_mask(self, N, dtype):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
        return mask
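
    # sanitize rewrites Hugging Face CLIP text-encoder checkpoint keys into
    # this module's parameter names, e.g. (hypothetical key)
    #   "text_model.encoder.layers.0.self_attn.q_proj.weight"
    #     -> "layers.0.attention.query_proj.weight"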
    def sanitize(self, weights):
        new_weights = {}
        for key, w in weights.items():
            # Remove prefixes
            if key.startswith("text_model."):
                key = key[11:]
            if key.startswith("embeddings."):
                key = key[11:]
            if key.startswith("encoder."):
                key = key[8:]

            # Map attention layers
            if "self_attn." in key:
                key = key.replace("self_attn.", "attention.")
            if "q_proj." in key:
                key = key.replace("q_proj.", "query_proj.")
            if "k_proj." in key:
                key = key.replace("k_proj.", "key_proj.")
            if "v_proj." in key:
                key = key.replace("v_proj.", "value_proj.")

            # Map ffn layers
            if "mlp.fc1" in key:
                key = key.replace("mlp.fc1", "linear1")
            if "mlp.fc2" in key:
                key = key.replace("mlp.fc2", "linear2")

            new_weights[key] = w

        return new_weights
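
    # __call__ expects integer token ids of shape (B, N). The pooled output is
    # taken at the EOS position, located with argmax because the EOS token has
    # the largest id in the CLIP vocabulary.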
    def __call__(self, x):
        # Extract some shapes
        B, N = x.shape
        eos_tokens = x.argmax(-1)

        # Compute the embeddings
        x = self.token_embedding(x)
        x = x + self.position_embedding.weight[:N]

        # Compute the features from the transformer
        mask = self._get_mask(N, x.dtype)
        hidden_states = []
        for l in self.layers:
            x = l(x, mask)
            hidden_states.append(x)

        # Apply the final layernorm and return
        x = self.final_layer_norm(x)
        last_hidden_state = x

        # Select the EOS token
        pooled_output = x[mx.arange(len(x)), eos_tokens]

        return CLIPOutput(
            pooled_output=pooled_output,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
        )
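

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: build a
    # randomly initialized text encoder and run it on placeholder token ids.
    # Real checkpoints would be remapped with sanitize() and then loaded
    # (e.g. with model.load_weights()) before use.
    config = CLIPTextModelConfig(num_layers=2, model_dims=64, num_heads=4)
    model = CLIPTextModel(config)
    tokens = mx.zeros((1, config.max_length), dtype=mx.int32)  # placeholder ids
    out = model(tokens)
    print(out.pooled_output.shape)      # (1, 64)
    print(out.last_hidden_state.shape)  # (1, 77, 64)
    print(len(out.hidden_states))       # 2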