Working clip, t5 and flux model

Angelos Katharopoulos
2024-09-27 13:37:02 -07:00
parent ed17f815f5
commit 63932d777c
5 changed files with 670 additions and 26 deletions

flux/flux/clip.py (new file)

@@ -0,0 +1,152 @@
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return cls(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
hidden_act=config["hidden_act"],
)
@dataclass
class CLIPOutput:
# The last_hidden_state indexed at the EOS token and possibly projected if
# the model has a projection layer
pooled_output: Optional[mx.array] = None
# The full sequence output of the transformer after the final layernorm
last_hidden_state: Optional[mx.array] = None
# A list of hidden states corresponding to the outputs of the transformer layers
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def sanitize(self, weights):
new_weights = {}
for key, w in weights.items():
# Remove prefixes
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
new_weights[key] = w
return new_weights
def __call__(self, x):
# Extract some shapes
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
hidden_states = []
for l in self.layers:
x = l(x, mask)
hidden_states.append(x)
# Apply the final layernorm and return
x = self.final_layer_norm(x)
last_hidden_state = x
# Select the EOS token
pooled_output = x[mx.arange(len(x)), eos_tokens]
return CLIPOutput(
pooled_output=pooled_output,
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
)
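
A minimal usage sketch, not part of the diff: it builds a deliberately tiny, randomly initialized text model just to show the shapes carried by CLIPOutput. Real token ids would come from the CLIP BPE tokenizer, which this commit does not include, and the import path assumes the package is importable as flux.

import mlx.core as mx

from flux.clip import CLIPTextModel, CLIPTextModelConfig

# Tiny, randomly initialized config; real values come from the HF config
# via CLIPTextModelConfig.from_dict in the loader added later in this commit.
config = CLIPTextModelConfig(num_layers=2, model_dims=64, num_heads=4)
clip = CLIPTextModel(config)

# Placeholder token ids padded to max_length; a real pipeline would use the
# CLIP BPE tokenizer here.
tokens = mx.zeros((1, config.max_length), dtype=mx.int32)

out = clip(tokens)
print(out.last_hidden_state.shape)  # (1, 77, 64)
print(out.pooled_output.shape)      # (1, 64)
print(len(out.hidden_states))       # 2, one entry per encoder layer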

flux/flux/layers.py (modified)

@@ -1,10 +1,78 @@
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from functools import partial
from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
def _rope(pos: mx.array, dim: int, theta: float):
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
omega = 1.0 / (theta**scale)
x = pos[..., None] * omega
cosx = mx.cos(x)
sinx = mx.sin(x)
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
pe = pe.reshape(*pe.shape[:-1], 2, 2)
return pe
@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
return a * b + c * d
def _apply_rope(x, pe):
s = x.shape
x = x.reshape(*s[:-1], -1, 1, 2)
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
return x.reshape(s)
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
B, H, L, D = q.shape
q = _apply_rope(q, pe)
k = _apply_rope(k, pe)
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
def timestep_embedding(
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
half = dim // 2
freqs = mx.arange(0, half, dtype=mx.float32) / half
freqs = freqs * (-math.log(max_period))
freqs = mx.exp(freqs)
x = (time_factor * t)[:, None] * freqs[None]
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
return x.astype(t.dtype)
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def __call__(self, ids: mx.array):
n_axes = ids.shape[-1]
pe = mx.concatenate(
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
axis=-3,
)
return pe[:, None]
class MLPEmbedder(nn.Module): class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int): def __init__(self, in_dim: int, hidden_dim: int):
super().__init__() super().__init__()
@@ -34,7 +102,6 @@ class SelfAttention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim) self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim) self.proj = nn.Linear(dim, dim)
self.rope = nn.RoPE(head_dim, True, base=10000)
def __call__(self, x: mx.array, pe: mx.array) -> mx.array: def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
H = self.num_heads H = self.num_heads
@@ -45,10 +112,7 @@ class SelfAttention(nn.Module):
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k) q, k = self.norm(q, k)
q = self.rope(q) x = _attention(q, k, v, pe)
k = self.rope(k)
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
x = self.proj(x) x = self.proj(x)
return x return x
@@ -95,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_mlp = nn.Sequential( self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"), nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True), nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
) )
self.txt_mod = Modulation(hidden_size, double=True) self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_attn = SelfAttention( self.txt_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
) )
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential( self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"), nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True), nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
) )
@@ -130,7 +194,7 @@ class DoubleStreamBlock(nn.Module):
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_q, img_k = self.norm(img_q, img_k) img_q, img_k = self.img_attn.norm(img_q, img_k)
# prepare txt for attention # prepare txt for attention
txt_modulated = self.txt_norm1(txt) txt_modulated = self.txt_norm1(txt)
@@ -140,19 +204,14 @@ class DoubleStreamBlock(nn.Module):
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3) txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3) txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3) txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_q, txt_k = self.norm(txt_q, txt_k) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
# run actual attention # run actual attention
q = mx.concatenate([txt_q, img_q], axis=2) q = mx.concatenate([txt_q, img_q], axis=2)
k = mx.concatenate([txt_k, img_k], axis=2) k = mx.concatenate([txt_k, img_k], axis=2)
v = mx.concatenate([txt_v, img_v], axis=2) v = mx.concatenate([txt_v, img_v], axis=2)
q = self.img_attn.rope(q) attn = _attention(q, k, v, pe)
k = self.img_attn.rope(k)
attn = mx.fast.scaled_dot_product_attention(
q, k, v, scale=q.shape[-1] ** (-0.5)
)
attn = attn.transpose(0, 2, 1, 3).reshape(B, L + S, -1)
txt_attn, img_attn = mx.split(attn, [S], axis=1) txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img bloks # calculate the img bloks
@@ -195,11 +254,9 @@ class SingleStreamBlock(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approx="tanh")
self.modulation = Modulation(hidden_size, double=False) self.modulation = Modulation(hidden_size, double=False)
self.rope = nn.RoPE(head_dim, True, base=10000)
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array): def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
B, L, _ = x.shape B, L, _ = x.shape
H = self.num_heads H = self.num_heads
@@ -218,10 +275,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k) q, k = self.norm(q, k)
# compute attention # compute attention
q = self.rope(q) y = _attention(q, k, v, pe)
k = self.rope(k)
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
y = y.transpose(0, 2, 1, 3).reshape(B, L, -1)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2)) y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
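
A small shape check, not part of the diff, showing how the new helpers replace the per-layer nn.RoPE modules: EmbedND turns multi-axis position ids into a table of 2x2 rotations, and _attention rotates q and k with it before the fused scaled_dot_product_attention call. The import path assumes the package is importable as flux; _attention and EmbedND live in the file modified above.

import mlx.core as mx

from flux.layers import EmbedND, _attention

B, L, num_heads, head_dim = 1, 4, 24, 128

# Three position axes per token (text index, latent row, latent column);
# zeros are enough to exercise the shapes.
ids = mx.zeros((B, L, 3))
pe = EmbedND(dim=head_dim, theta=10000, axes_dim=[16, 56, 56])(ids)
print(pe.shape)  # (1, 1, 4, 64, 2, 2): one 2x2 rotation per channel pair

q = mx.random.normal((B, num_heads, L, head_dim))
k = mx.random.normal((B, num_heads, L, head_dim))
v = mx.random.normal((B, num_heads, L, head_dim))

# q and k are rotated by pe inside _attention, then passed to the fused SDPA.
out = _attention(q, k, v, pe)
print(out.shape)  # (1, 4, 3072) == (B, L, num_heads * head_dim)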

flux/flux/model.py (modified)

@@ -10,6 +10,7 @@ from .layers import (
     LastLayer,
     MLPEmbedder,
     SingleStreamBlock,
+    timestep_embedding,
 )
@@ -40,6 +41,56 @@ class Flux(nn.Module):
             raise ValueError(
                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
             )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(
+                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+            )
+
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.pe_embedder = EmbedND(
+            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
+        )
+        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.guidance_in = (
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+            if params.guidance_embed
+            else nn.Identity()
+        )
+        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = [
+            DoubleStreamBlock(
+                self.hidden_size,
+                self.num_heads,
+                mlp_ratio=params.mlp_ratio,
+                qkv_bias=params.qkv_bias,
+            )
+            for _ in range(params.depth)
+        ]
+
+        self.single_blocks = [
+            SingleStreamBlock(
+                self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
+            )
+            for _ in range(params.depth_single_blocks)
+        ]
+
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if k.endswith(".scale"):
+                k = k[:-6] + ".weight"
+            for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
+                if f".{seq}." in k:
+                    k = k.replace(f".{seq}.", f".{seq}.layers.")
+                    break
+            new_weights[k] = w
+        return new_weights

     def __call__(
         self,
@@ -51,4 +102,31 @@ class Flux(nn.Module):
         y: mx.array,
         guidance: Optional[mx.array] = None,
     ) -> mx.array:
-        pass
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        img = self.img_in(img)
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            if guidance is None:
+                raise ValueError(
+                    "Didn't get guidance strength for guidance distilled model."
+                )
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = mx.concatenate([txt_ids, img_ids], axis=1)
+        pe = self.pe_embedder(ids)
+
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+        img = mx.concatenate([txt, img], axis=1)
+        for block in self.single_blocks:
+            img = block(img, vec=vec, pe=pe)
+        img = img[:, txt.shape[1] :, ...]
+
+        img = self.final_layer(img, vec)
+        return img
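
A sketch of the inputs Flux.__call__ now expects, not part of the diff. The packing below follows the reference Flux convention (2x2 latent patches flattened to 64 channels, img_ids holding (0, row, col), txt_ids all zeros); that convention is an assumption here, not something established by this commit.

import mlx.core as mx

B, C, H, W = 1, 16, 64, 64   # VAE latent: 16 channels, 64x64 (for a 512x512 image)
T = 256                      # T5 sequence length

# Pack 2x2 latent patches into the (B, L, 64) sequence that img expects.
latent = mx.random.normal((B, C, H, W))
img = latent.reshape(B, C, H // 2, 2, W // 2, 2)
img = img.transpose(0, 2, 4, 1, 3, 5).reshape(B, (H // 2) * (W // 2), C * 4)

# img_ids holds (0, row, col) for every packed patch; txt_ids is all zeros.
rows = mx.broadcast_to(mx.arange(H // 2)[:, None], (H // 2, W // 2))
cols = mx.broadcast_to(mx.arange(W // 2)[None, :], (H // 2, W // 2))
img_ids = mx.stack([mx.zeros_like(rows), rows, cols], axis=-1).reshape(1, -1, 3)
img_ids = mx.broadcast_to(img_ids, (B, (H // 2) * (W // 2), 3))
txt_ids = mx.zeros((B, T, 3))

# With a constructed model and the text encoders loaded via the functions at
# the end of this commit, the forward pass would then be:
# noise_pred = flux(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
#                   timesteps=mx.array([1.0]), y=clip_pooled, guidance=mx.array([4.0]))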

flux/flux/t5.py (new file)

@@ -0,0 +1,311 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
_SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
_ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
@dataclass
class T5Config:
vocab_size: int
num_layers: int
num_heads: int
relative_attention_num_buckets: int
d_kv: int
d_model: int
feed_forward_proj: str
tie_word_embeddings: bool
d_ff: Optional[int] = None
num_decoder_layers: Optional[int] = None
relative_attention_max_distance: int = 128
layer_norm_epsilon: float = 1e-6
@classmethod
def from_dict(cls, config):
return cls(
vocab_size=config["vocab_size"],
num_layers=config["num_layers"],
num_heads=config["num_heads"],
relative_attention_num_buckets=config["relative_attention_num_buckets"],
d_kv=config["d_kv"],
d_model=config["d_model"],
feed_forward_proj=config["feed_forward_proj"],
tie_word_embeddings=config["tie_word_embeddings"],
d_ff=config.get("d_ff", 4 * config["d_model"]),
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
relative_attention_max_distance=config.get(
"relative_attention_max_distance", 128
),
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
)
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
@staticmethod
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
num_buckets = num_buckets // 2 if bidirectional else num_buckets
max_exact = num_buckets // 2
abspos = rpos.abs()
is_small = abspos < max_exact
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
        buckets = mx.where(is_small, abspos, buckets_large)
if bidirectional:
buckets = buckets + (rpos > 0) * num_buckets
else:
buckets = buckets * (rpos < 0)
return buckets
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
        if cache is not None:
            key_cache, value_cache = cache
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        # The mask carries the additive relative position bias.
        if mask is not None:
            mask = mask.astype(queries.dtype)
        values_hat = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=1.0, mask=mask
        )
        values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(values_hat), (keys, values)
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(
self,
x: mx.array,
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
y, _ = self.cross_attention(y, memory, memory, memory_mask)
x = x + y
y = self.ln3(x)
y = self.dense(y)
x = x + y
return x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
def __call__(self, x, memory, mask, memory_mask, cache=None):
if cache is not None:
            offset = cache[0][0].shape[2]
else:
offset = 0
cache = [None] * len(self.layers)
T = offset + x.shape[1]
pos_bias = self.relative_attention_bias(T, T, offset=offset)
if mask is not None:
mask += pos_bias
else:
mask = pos_bias
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
x = self.ln(x)
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
return self.linear(inputs)
class T5Encoder(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
for old, new in _SHARED_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
if k.startswith("encoder."):
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
new_weights[k] = w
return new_weights
def __call__(self, inputs: mx.array):
return self.encoder(self.wte(inputs))
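
A quick check of the replacement patterns above, not part of the diff: it maps a few Hugging Face T5 parameter names through sanitize. The tiny config values are arbitrary and only needed to build an instance.

from flux.t5 import T5Config, T5Encoder

config = T5Config(
    vocab_size=32, num_layers=1, num_heads=2, relative_attention_num_buckets=8,
    d_kv=4, d_model=8, feed_forward_proj="gated-gelu", tie_word_embeddings=False,
    d_ff=16,
)
t5 = T5Encoder(config)

hf_keys = [
    "shared.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight",
    "encoder.block.0.layer.0.layer_norm.weight",
    "encoder.block.0.layer.1.DenseReluDense.wi_0.weight",
    "encoder.final_layer_norm.weight",
]
for old, new in zip(hf_keys, t5.sanitize({k: None for k in hf_keys})):
    print(f"{old} -> {new}")
# shared.weight -> wte.weight
# encoder.block.0.layer.0.SelfAttention.q.weight -> encoder.layers.0.attention.query_proj.weight
# encoder.block.0.layer.0.layer_norm.weight -> encoder.layers.0.ln1.weight
# encoder.block.0.layer.1.DenseReluDense.wi_0.weight -> encoder.layers.0.dense.wi_0.weight
# encoder.final_layer_norm.weight -> encoder.ln.weight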

flux/flux/utils.py (modified)

@@ -1,3 +1,4 @@
+import json
 import os
 from dataclasses import dataclass
 from typing import Optional
@@ -6,7 +7,9 @@ import mlx.core as mx
 from huggingface_hub import hf_hub_download

 from .autoencoder import AutoEncoder, AutoEncoderParams
+from .clip import CLIPTextModel, CLIPTextModelConfig
 from .model import Flux, FluxParams
+from .t5 import T5Config, T5Encoder


 @dataclass
@@ -136,3 +139,49 @@ def load_ae(name: str, hf_download: bool = True):
     ae.load_weights(list(weights.items()))

     return ae
+
+
+def load_t5(name: str):
+    # Load the config
+    config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
+    with open(config_path) as f:
+        config = T5Config.from_dict(json.load(f))
+
+    # Make the T5 model
+    t5 = T5Encoder(config)
+
+    # Load the weights
+    model_index = hf_hub_download(
+        configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
+    )
+    weight_files = set()
+    with open(model_index) as f:
+        for _, w in json.load(f)["weight_map"].items():
+            weight_files.add(w)
+    weights = {}
+    for w in weight_files:
+        w = f"text_encoder_2/{w}"
+        w = hf_hub_download(configs[name].repo_id, w)
+        weights.update(mx.load(w))
+    weights = t5.sanitize(weights)
+    t5.load_weights(list(weights.items()))
+
+    return t5
+
+
+def load_clip(name: str):
+    # Load the config
+    config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
+    with open(config_path) as f:
+        config = CLIPTextModelConfig.from_dict(json.load(f))
+
+    # Make the clip text encoder
+    clip = CLIPTextModel(config)
+
+    # Load the weights
+    ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
+    weights = mx.load(ckpt_path)
+    weights = clip.sanitize(weights)
+    clip.load_weights(list(weights.items()))
+
+    return clip
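
Not part of the diff: a sketch of how the two loaders combine to produce the conditioning that Flux.__call__ consumes. The module path and the "flux-schnell" config key are assumptions about the surrounding code, and the tokenizers are placeholders since tokenization is not part of this commit.

import mlx.core as mx

from flux.utils import load_clip, load_t5  # assumed module path for this file

t5 = load_t5("flux-schnell")      # assumed key into the configs dict
clip = load_clip("flux-schnell")

# Placeholder token ids; a real pipeline would run the T5 and CLIP tokenizers.
t5_tokens = mx.zeros((1, 256), dtype=mx.int32)
clip_tokens = mx.zeros((1, 77), dtype=mx.int32)

txt = t5(t5_tokens)                   # (1, 256, 4096): sequence conditioning for txt_in
y = clip(clip_tokens).pooled_output   # (1, 768): pooled conditioning for vector_in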