Working clip, t5 and flux model

Mirror of https://github.com/ml-explore/mlx-examples.git
flux/flux/clip.py (new file, 152 lines)
@@ -0,0 +1,152 @@
from dataclasses import dataclass
from typing import List, Optional

import mlx.core as mx
import mlx.nn as nn

_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}


@dataclass
class CLIPTextModelConfig:
    num_layers: int = 23
    model_dims: int = 1024
    num_heads: int = 16
    max_length: int = 77
    vocab_size: int = 49408
    hidden_act: str = "quick_gelu"

    @classmethod
    def from_dict(cls, config):
        return cls(
            num_layers=config["num_hidden_layers"],
            model_dims=config["hidden_size"],
            num_heads=config["num_attention_heads"],
            max_length=config["max_position_embeddings"],
            vocab_size=config["vocab_size"],
            hidden_act=config["hidden_act"],
        )


@dataclass
class CLIPOutput:
    # The last_hidden_state indexed at the EOS token and possibly projected if
    # the model has a projection layer
    pooled_output: Optional[mx.array] = None

    # The full sequence output of the transformer after the final layernorm
    last_hidden_state: Optional[mx.array] = None

    # A list of hidden states corresponding to the outputs of the transformer layers
    hidden_states: Optional[List[mx.array]] = None


class CLIPEncoderLayer(nn.Module):
    """The transformer encoder layer from CLIP."""

    def __init__(self, model_dims: int, num_heads: int, activation: str):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(model_dims)
        self.layer_norm2 = nn.LayerNorm(model_dims)

        self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)

        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
        self.linear2 = nn.Linear(4 * model_dims, model_dims)

        self.act = _ACTIVATIONS[activation]

    def __call__(self, x, attn_mask=None):
        y = self.layer_norm1(x)
        y = self.attention(y, y, y, attn_mask)
        x = y + x

        y = self.layer_norm2(x)
        y = self.linear1(y)
        y = self.act(y)
        y = self.linear2(y)
        x = y + x

        return x


class CLIPTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextModelConfig):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
        self.layers = [
            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
            for i in range(config.num_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(config.model_dims)

    def _get_mask(self, N, dtype):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
        return mask

    def sanitize(self, weights):
        new_weights = {}
        for key, w in weights.items():
            # Remove prefixes
            if key.startswith("text_model."):
                key = key[11:]
            if key.startswith("embeddings."):
                key = key[11:]
            if key.startswith("encoder."):
                key = key[8:]

            # Map attention layers
            if "self_attn." in key:
                key = key.replace("self_attn.", "attention.")
            if "q_proj." in key:
                key = key.replace("q_proj.", "query_proj.")
            if "k_proj." in key:
                key = key.replace("k_proj.", "key_proj.")
            if "v_proj." in key:
                key = key.replace("v_proj.", "value_proj.")

            # Map ffn layers
            if "mlp.fc1" in key:
                key = key.replace("mlp.fc1", "linear1")
            if "mlp.fc2" in key:
                key = key.replace("mlp.fc2", "linear2")

            new_weights[key] = w

        return new_weights

    def __call__(self, x):
        # Extract some shapes
        B, N = x.shape
        eos_tokens = x.argmax(-1)

        # Compute the embeddings
        x = self.token_embedding(x)
        x = x + self.position_embedding.weight[:N]

        # Compute the features from the transformer
        mask = self._get_mask(N, x.dtype)
        hidden_states = []
        for l in self.layers:
            x = l(x, mask)
            hidden_states.append(x)

        # Apply the final layernorm and return
        x = self.final_layer_norm(x)
        last_hidden_state = x

        # Select the EOS token
        pooled_output = x[mx.arange(len(x)), eos_tokens]

        return CLIPOutput(
            pooled_output=pooled_output,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
        )
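A minimal usage sketch, not part of the commit: with the default config above the encoder can be built and run on placeholder token ids without loading pretrained weights. The `flux.clip` import path is an assumption about the package layout.

import mlx.core as mx

from flux.clip import CLIPTextModel, CLIPTextModelConfig  # assumed import path

config = CLIPTextModelConfig()      # defaults above: 23 layers, 1024 dims, 77 positions
model = CLIPTextModel(config)

tokens = mx.zeros((1, 77), dtype=mx.int32)   # placeholder token ids
out = model(tokens)
print(out.pooled_output.shape)       # (1, 1024), taken at the EOS (argmax) position
print(out.last_hidden_state.shape)   # (1, 77, 1024)
print(len(out.hidden_states))        # 23, one entry per encoder layer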
flux/flux/layers.py

@@ -1,10 +1,78 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple
from functools import partial
+from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn


def _rope(pos: mx.array, dim: int, theta: float):
    scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
    omega = 1.0 / (theta**scale)
    x = pos[..., None] * omega
    cosx = mx.cos(x)
    sinx = mx.sin(x)
    pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
    pe = pe.reshape(*pe.shape[:-1], 2, 2)

    return pe


@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
    return a * b + c * d


def _apply_rope(x, pe):
    s = x.shape
    x = x.reshape(*s[:-1], -1, 1, 2)
    x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
    return x.reshape(s)


def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
    B, H, L, D = q.shape

    q = _apply_rope(q, pe)
    k = _apply_rope(k, pe)
    x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))

    return x.transpose(0, 2, 1, 3).reshape(B, L, -1)


def timestep_embedding(
    t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
    half = dim // 2
    freqs = mx.arange(0, half, dtype=mx.float32) / half
    freqs = freqs * (-math.log(max_period))
    freqs = mx.exp(freqs)

    x = (time_factor * t)[:, None] * freqs[None]
    x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)

    return x.astype(t.dtype)


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()

        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def __call__(self, ids: mx.array):
        n_axes = ids.shape[-1]
        pe = mx.concatenate(
            [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            axis=-3,
        )

        return pe[:, None]


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
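A small shape sketch, not from the commit and assuming the helpers above are in scope (e.g. run inside flux/flux/layers.py): EmbedND concatenates one rotary table per position axis and adds the broadcast head dimension that _attention expects. The axes_dim values are an assumption chosen so they sum to the per-head dimension.

import mlx.core as mx

embedder = EmbedND(dim=128, theta=10000, axes_dim=[16, 56, 56])

B, T_txt, T_img = 1, 8, 16
txt_ids = mx.zeros((B, T_txt, 3))
img_ids = mx.zeros((B, T_img, 3))
ids = mx.concatenate([txt_ids, img_ids], axis=1)   # (B, 24, 3)

pe = embedder(ids)
print(pe.shape)   # (1, 1, 24, 64, 2, 2): one 2x2 rotation per feature pair

# timestep_embedding maps a batch of scalar timesteps to dim-sized sinusoidal features.
t = mx.array([0.5])
print(timestep_embedding(t, 256).shape)   # (1, 256)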
@@ -34,7 +102,6 @@ class SelfAttention(nn.Module):
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)
-        self.rope = nn.RoPE(head_dim, True, base=10000)

    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
        H = self.num_heads
@@ -45,10 +112,7 @@
        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        q, k = self.norm(q, k)
-        q = self.rope(q)
-        k = self.rope(k)
-        x = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
-        x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        x = _attention(q, k, v, pe)
        x = self.proj(x)
        return x
@@ -95,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
+            nn.GELU(approx="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
-        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

-        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
+            nn.GELU(approx="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
@@ -130,7 +194,7 @@ class DoubleStreamBlock(nn.Module):
        img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
-        img_q, img_k = self.norm(img_q, img_k)
+        img_q, img_k = self.img_attn.norm(img_q, img_k)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
@@ -140,19 +204,14 @@ class DoubleStreamBlock(nn.Module):
        txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
-        txt_q, txt_k = self.norm(txt_q, txt_k)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

        # run actual attention
        q = mx.concatenate([txt_q, img_q], axis=2)
        k = mx.concatenate([txt_k, img_k], axis=2)
        v = mx.concatenate([txt_v, img_v], axis=2)

-        q = self.img_attn.rope(q)
-        k = self.img_attn.rope(k)
-        attn = mx.fast.scaled_dot_product_attention(
-            q, k, v, scale=q.shape[-1] ** (-0.5)
-        )
-        attn = attn.transpose(0, 2, 1, 3).reshape(B, L + S, -1)
+        attn = _attention(q, k, v, pe)
        txt_attn, img_attn = mx.split(attn, [S], axis=1)

        # calculate the img blocks
@@ -195,11 +254,9 @@ class SingleStreamBlock(nn.Module):
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)

-        self.mlp_act = nn.GELU(approximate="tanh")
+        self.mlp_act = nn.GELU(approx="tanh")
        self.modulation = Modulation(hidden_size, double=False)

-        self.rope = nn.RoPE(head_dim, True, base=10000)

    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
        B, L, _ = x.shape
        H = self.num_heads
@@ -218,10 +275,7 @@ class SingleStreamBlock(nn.Module):
        q, k = self.norm(q, k)

        # compute attention
-        q = self.rope(q)
-        k = self.rope(k)
-        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
-        y = y.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        y = _attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
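A quick consistency sketch, not from the commit and assuming _rope, _apply_rope and _attention are in scope: the fused _attention helper introduced above is intended to reproduce the rope-then-scaled-dot-product-attention sequence that the removed lines spelled out.

import mlx.core as mx

B, H, L, D = 1, 2, 6, 16
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
pe = _rope(mx.arange(L, dtype=mx.float32)[None], D, 10000.0)[:, None]  # (1, 1, L, D // 2, 2, 2)

fused = _attention(q, k, v, pe)                       # (B, L, H * D)

q_r, k_r = _apply_rope(q, pe), _apply_rope(k, pe)
manual = mx.fast.scaled_dot_product_attention(q_r, k_r, v, scale=D ** (-0.5))
manual = manual.transpose(0, 2, 1, 3).reshape(B, L, -1)

print(mx.allclose(fused, manual))   # True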
flux/flux/model.py

@@ -10,6 +10,7 @@ from .layers import (
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

@@ -40,6 +41,56 @@ class Flux(nn.Module):
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if params.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = [
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
            )
            for _ in range(params.depth)
        ]

        self.single_blocks = [
            SingleStreamBlock(
                self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
            )
            for _ in range(params.depth_single_blocks)
        ]

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def sanitize(self, weights):
        new_weights = {}
        for k, w in weights.items():
            if k.endswith(".scale"):
                k = k[:-6] + ".weight"
            for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
                if f".{seq}." in k:
                    k = k.replace(f".{seq}.", f".{seq}.layers.")
                    break
            new_weights[k] = w
        return new_weights

    def __call__(
        self,
@@ -51,4 +102,31 @@ class Flux(nn.Module):
        y: mx.array,
        guidance: Optional[mx.array] = None,
    ) -> mx.array:
-        pass
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = mx.concatenate([txt_ids, img_ids], axis=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = mx.concatenate([txt, img], axis=1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)

        return img
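A hypothetical call sketch for the forward pass above. The helper function and every shape below are assumptions (packed 64-channel latents, 4096-dimensional T5 features, a 768-dimensional pooled CLIP vector), not values taken from this commit; the keyword names follow the body of __call__.

import mlx.core as mx


def denoise_step(flux, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
    # One forward pass through an already constructed and loaded Flux instance.
    return flux(
        img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
        timesteps=timesteps, y=y, guidance=guidance,
    )


# Assumed shapes: (B, image tokens, packed channels), three position axes per token,
# (B, text tokens, context_in_dim), and a pooled conditioning vector y.
B, L_img, L_txt = 1, 1024, 256
example_inputs = dict(
    img=mx.zeros((B, L_img, 64)),
    img_ids=mx.zeros((B, L_img, 3)),
    txt=mx.zeros((B, L_txt, 4096)),
    txt_ids=mx.zeros((B, L_txt, 3)),
    timesteps=mx.array([1.0]),
    y=mx.zeros((B, 768)),
    guidance=mx.array([4.0]),
)
# noise_pred = denoise_step(flux, **example_inputs)   # (B, L_img, out_channels)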
flux/flux/t5.py (new file, 311 lines)
@@ -0,0 +1,311 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

_SHARED_REPLACEMENT_PATTERNS = [
    (".block.", ".layers."),
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    ("shared.", "wte."),
    ("lm_head.", "lm_head.linear."),
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]

_ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]


@dataclass
class T5Config:
    vocab_size: int
    num_layers: int
    num_heads: int
    relative_attention_num_buckets: int
    d_kv: int
    d_model: int
    feed_forward_proj: str
    tie_word_embeddings: bool

    d_ff: Optional[int] = None
    num_decoder_layers: Optional[int] = None
    relative_attention_max_distance: int = 128
    layer_norm_epsilon: float = 1e-6

    @classmethod
    def from_dict(cls, config):
        return cls(
            vocab_size=config["vocab_size"],
            num_layers=config["num_layers"],
            num_heads=config["num_heads"],
            relative_attention_num_buckets=config["relative_attention_num_buckets"],
            d_kv=config["d_kv"],
            d_model=config["d_model"],
            feed_forward_proj=config["feed_forward_proj"],
            tie_word_embeddings=config["tie_word_embeddings"],
            d_ff=config.get("d_ff", 4 * config["d_model"]),
            num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
            relative_attention_max_distance=config.get(
                "relative_attention_max_distance", 128
            ),
            layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
        )


class RelativePositionBias(nn.Module):
    def __init__(self, config: T5Config, bidirectional: bool):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = config.relative_attention_num_buckets
        self.max_distance = config.relative_attention_max_distance
        self.n_heads = config.num_heads
        self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
        num_buckets = num_buckets // 2 if bidirectional else num_buckets
        max_exact = num_buckets // 2

        abspos = rpos.abs()
        is_small = abspos < max_exact

        scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
        buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
        buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)

        buckets = mx.where(is_small, abspos, buckets_large)
        if bidirectional:
            buckets = buckets + (rpos > 0) * num_buckets
        else:
            buckets = buckets * (rpos < 0)

        return buckets

    def __call__(self, query_length: int, key_length: int, offset: int = 0):
        """Compute binned relative position bias"""
        context_position = mx.arange(offset, query_length)[:, None]
        memory_position = mx.arange(key_length)[None, :]

        # shape (query_length, key_length)
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )

        # shape (query_length, key_length, num_heads)
        values = self.embeddings(relative_position_bucket)

        # shape (num_heads, query_length, key_length)
        return values.transpose(2, 0, 1)
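A tiny worked example of the bucketing above, not part of the commit and assuming the class is in scope: with num_buckets=8 in bidirectional mode only 4 buckets cover each direction, the first max_exact=2 of them exactly and the rest logarithmically out to max_distance; positive offsets are then shifted into the upper half.

import mlx.core as mx

rpos = mx.array([[-3, -1, 0, 1, 3]])
buckets = RelativePositionBias._relative_position_bucket(
    rpos, bidirectional=True, num_buckets=8, max_distance=16
)
print(buckets)   # [[2, 1, 0, 5, 6]]: -3, -1, 0 map to 2, 1, 0 and +1, +3 to 1 + 4, 2 + 4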

class MultiHeadAttention(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        inner_dim = config.d_kv * config.num_heads
        self.num_heads = config.num_heads
        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)

    def __call__(
        self,
        queries: mx.array,
        keys: mx.array,
        values: mx.array,
        mask: Optional[mx.array],
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)

        # T5 uses unscaled attention with the relative position bias as an additive mask
        values_hat = mx.fast.scaled_dot_product_attention(
            queries,
            keys,
            values,
            scale=1.0,
            mask=mask.astype(queries.dtype) if mask is not None else None,
        )
        values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)

class DenseActivation(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        mlp_dims = config.d_ff or config.d_model * 4
        self.gated = config.feed_forward_proj.startswith("gated")
        if self.gated:
            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
        else:
            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
        activation = config.feed_forward_proj.removeprefix("gated-")
        if activation == "relu":
            self.act = nn.relu
        elif activation == "gelu":
            self.act = nn.gelu
        elif activation == "silu":
            self.act = nn.silu
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def __call__(self, x):
        if self.gated:
            hidden_act = self.act(self.wi_0(x))
            hidden_linear = self.wi_1(x)
            x = hidden_act * hidden_linear
        else:
            x = self.act(self.wi(x))
        return self.wo(x)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dense = DenseActivation(config)

    def __call__(self, x, mask):
        y = self.ln1(x)
        y, _ = self.attention(y, y, y, mask=mask)
        x = x + y

        y = self.ln2(x)
        y = self.dense(y)
        return x + y


class TransformerEncoder(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(config) for i in range(config.num_layers)
        ]
        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)

    def __call__(self, x: mx.array):
        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
        for layer in self.layers:
            x = layer(x, mask=pos_bias)
        return self.ln(x)


class TransformerDecoderLayer(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.self_attention = MultiHeadAttention(config)
        self.cross_attention = MultiHeadAttention(config)
        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dense = DenseActivation(config)

    def __call__(
        self,
        x: mx.array,
        memory: mx.array,
        mask: mx.array,
        memory_mask: mx.array,
        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
    ):
        y = self.ln1(x)
        y, cache = self.self_attention(y, y, y, mask, cache)
        x = x + y

        y = self.ln2(x)
        y, _ = self.cross_attention(y, memory, memory, memory_mask)
        x = x + y

        y = self.ln3(x)
        y = self.dense(y)
        x = x + y

        return x, cache


class TransformerDecoder(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)

    def __call__(self, x, memory, mask, memory_mask, cache=None):
        if cache is not None:
            offset = cache[0][0].shape[2]
        else:
            offset = 0
            cache = [None] * len(self.layers)

        T = offset + x.shape[1]
        pos_bias = self.relative_attention_bias(T, T, offset=offset)
        if mask is not None:
            mask += pos_bias
        else:
            mask = pos_bias

        for e, layer in enumerate(self.layers):
            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
        x = self.ln(x)

        return x, cache


class OutputHead(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def __call__(self, inputs):
        return self.linear(inputs)


class T5Encoder(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = TransformerEncoder(config)

    def sanitize(self, weights):
        new_weights = {}
        for k, w in weights.items():
            for old, new in _SHARED_REPLACEMENT_PATTERNS:
                k = k.replace(old, new)
            if k.startswith("encoder."):
                for old, new in _ENCODER_REPLACEMENT_PATTERNS:
                    k = k.replace(old, new)
            new_weights[k] = w
        return new_weights

    def __call__(self, inputs: mx.array):
        return self.encoder(self.wte(inputs))
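A minimal sketch, not from the commit and assuming flux/flux/t5.py is importable: build a toy encoder from a hand-written config dict and run it on placeholder token ids. The dict values are assumptions for a small model; the real values come from text_encoder_2/config.json.

import mlx.core as mx

config = T5Config.from_dict({
    "vocab_size": 32128,
    "num_layers": 2,
    "num_heads": 4,
    "relative_attention_num_buckets": 32,
    "d_kv": 16,
    "d_model": 64,
    "feed_forward_proj": "gated-gelu",
    "tie_word_embeddings": False,
})

t5 = T5Encoder(config)
tokens = mx.zeros((1, 12), dtype=mx.int32)
print(t5(tokens).shape)   # (1, 12, 64)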
flux/flux/utils.py

@@ -1,3 +1,4 @@
import json
import os
from dataclasses import dataclass
from typing import Optional

@@ -6,7 +7,9 @@ import mlx.core as mx
from huggingface_hub import hf_hub_download

from .autoencoder import AutoEncoder, AutoEncoderParams
from .clip import CLIPTextModel, CLIPTextModelConfig
from .model import Flux, FluxParams
from .t5 import T5Config, T5Encoder


@dataclass

@@ -136,3 +139,49 @@ def load_ae(name: str, hf_download: bool = True):
    ae.load_weights(list(weights.items()))

    return ae


def load_t5(name: str):
    # Load the config
    config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
    with open(config_path) as f:
        config = T5Config.from_dict(json.load(f))

    # Make the T5 model
    t5 = T5Encoder(config)

    # Load the weights
    model_index = hf_hub_download(
        configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
    )
    weight_files = set()
    with open(model_index) as f:
        for _, w in json.load(f)["weight_map"].items():
            weight_files.add(w)
    weights = {}
    for w in weight_files:
        w = f"text_encoder_2/{w}"
        w = hf_hub_download(configs[name].repo_id, w)
        weights.update(mx.load(w))
    weights = t5.sanitize(weights)
    t5.load_weights(list(weights.items()))

    return t5


def load_clip(name: str):
    # Load the config
    config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
    with open(config_path) as f:
        config = CLIPTextModelConfig.from_dict(json.load(f))

    # Make the clip text encoder
    clip = CLIPTextModel(config)

    # Load the weights
    ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
    weights = mx.load(ckpt_path)
    weights = clip.sanitize(weights)
    clip.load_weights(list(weights.items()))

    return clip
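A usage sketch for the two loaders above. The "flux-schnell" key is an assumption about the entries of the configs registry, and real conditioning would tokenize a prompt with the matching CLIP and T5 tokenizers instead of the placeholder ids shown here.

import mlx.core as mx

clip = load_clip("flux-schnell")   # assumed config key
t5 = load_t5("flux-schnell")

clip_tokens = mx.zeros((1, 77), dtype=mx.int32)    # placeholder CLIP token ids
t5_tokens = mx.zeros((1, 256), dtype=mx.int32)     # placeholder T5 token ids

pooled = clip(clip_tokens).pooled_output   # pooled vector, the Flux y input
txt = t5(t5_tokens)                        # per-token features, the Flux txt input
print(pooled.shape, txt.shape)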