mlx-examples/flux/flux/layers.py
2024-10-11 21:17:41 -07:00

303 lines
9.8 KiB
Python

# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
def _rope(pos: mx.array, dim: int, theta: float):
    """Build 2x2 rotary-embedding rotation matrices for the positions in ``pos``.

    One frequency is generated per index in ``range(0, dim, 2)``; frequency ``j``
    equals ``theta ** (-2j / dim)``. The result has shape ``pos.shape + (n, 2, 2)``
    with ``n = len(range(0, dim, 2))``; each trailing 2x2 matrix is
    ``[[cos(a), -sin(a)], [sin(a), cos(a)]]`` with ``a = pos * frequency``.
    """
    exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
    inv_freq = 1.0 / (theta**exponents)
    angles = pos[..., None] * inv_freq
    cos_a = mx.cos(angles)
    sin_a = mx.sin(angles)
    # Row-major flattening of [[cos, -sin], [sin, cos]], folded into 2x2 below.
    flat = mx.stack([cos_a, -sin_a, sin_a, cos_a], axis=-1)
    return flat.reshape(*flat.shape[:-1], 2, 2)
@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
    """Return the elementwise ``a * b + c * d``.

    Wrapped in ``mx.compile(shapeless=True)`` so the fused graph is not
    recompiled when input shapes change.
    """
    first = a * b
    second = c * d
    return first + second
def _apply_rope(x, pe):
    """Rotate consecutive feature pairs of ``x`` by the 2x2 matrices in ``pe``.

    The last axis of ``x`` is viewed as pairs ``(x0, x1)``; output component ``i``
    of each pair is ``x0 * pe[..., i, 0] + x1 * pe[..., i, 1]``. ``pe`` must
    broadcast so that one 2x2 matrix lines up with each pair. The returned array
    has the same shape as ``x``.
    """
    orig_shape = x.shape
    # (..., D) -> (..., D/2, 1, 2): the singleton axis broadcasts over matrix rows.
    pairs = x.reshape(*orig_shape[:-1], -1, 1, 2)
    rotated = _ab_plus_cd(pairs[..., 0], pe[..., 0], pairs[..., 1], pe[..., 1])
    return rotated.reshape(orig_shape)
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
    """Scaled dot-product attention with rotary embeddings applied to q and k.

    ``q`` is unpacked as ``(B, H, L, D)``. RoPE from ``pe`` is applied to ``q``
    and ``k`` (not ``v``), the logits are scaled by ``D ** -0.5``, and the heads
    are flattened back into the last axis, giving shape ``(B, L, -1)``.
    """
    batch, _, seq_len, head_dim = q.shape
    q_rot = _apply_rope(q, pe)
    k_rot = _apply_rope(k, pe)
    out = mx.fast.scaled_dot_product_attention(
        q_rot, k_rot, v, scale=head_dim ** (-0.5)
    )
    # (B, H, L, Dv) -> (B, L, H, Dv) -> (B, L, H * Dv)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
def timestep_embedding(
    t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
    """Sinusoidal embedding of one timestep per row.

    Args:
        t: 1-D array of timesteps (indexed as ``t[:, None]``).
        dim: Output embedding width. For odd ``dim`` the last column is zero
            padding, so the width is always exactly ``dim``.
        max_period: Sets the lowest frequency; the frequencies are
            ``max_period ** (-i / (dim // 2))`` for ``i in range(dim // 2)``.
        time_factor: Multiplier applied to ``t`` before taking the sinusoids.

    Returns:
        Array of shape ``(t.shape[0], dim)`` laid out as ``[cos | sin | pad]``.
        It is cast back to ``t.dtype`` only when ``t`` is floating point;
        for integer ``t`` the floating-point result is returned unchanged.
    """
    half = dim // 2
    freqs = mx.arange(0, half, dtype=mx.float32) / half
    freqs = freqs * (-math.log(max_period))
    freqs = mx.exp(freqs)
    x = (time_factor * t)[:, None] * freqs[None]
    x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
    if dim % 2:
        # cos/sin fill only 2 * (dim // 2) columns; pad so the width equals dim.
        x = mx.concatenate([x, mx.zeros((x.shape[0], 1), dtype=x.dtype)], axis=-1)
    if mx.issubdtype(t.dtype, mx.floating):
        # Casting to an integer dtype would truncate the sinusoids to {-1, 0, 1}.
        x = x.astype(t.dtype)
    return x
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding.

    Each of the ``ids.shape[-1]`` position axes gets its own ``_rope`` table of
    width ``axes_dim[i]``. The tables are concatenated along their frequency
    axis (axis -3), and a singleton axis is inserted at position 1. That axis is
    presumably the attention-head axis, so the result broadcasts over heads;
    confirm against callers.
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def __call__(self, ids: mx.array):
        tables = []
        for axis in range(ids.shape[-1]):
            tables.append(_rope(ids[..., axis], self.axes_dim[axis], self.theta))
        combined = mx.concatenate(tables, axis=-3)
        return combined[:, None]
class MLPEmbedder(nn.Module):
    """Two-layer MLP: Linear(in_dim -> hidden_dim), SiLU, Linear(hidden_dim -> hidden_dim)."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        hidden = self.in_layer(x)
        hidden = nn.silu(hidden)
        return self.out_layer(hidden)
class QKNorm(nn.Module):
    """Separate RMSNorm layers (feature size ``dim``) for queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = nn.RMSNorm(dim)
        self.key_norm = nn.RMSNorm(dim)

    def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q, normed_k
class SelfAttention(nn.Module):
    """Multi-head self-attention with QK RMS-normalisation and rotary embeddings.

    ``DoubleStreamBlock`` also reuses the ``qkv``, ``norm`` and ``proj``
    submodules of this class directly.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
        batch, seq_len, _ = x.shape
        heads = self.num_heads

        def split_heads(t):
            # (B, L, H*D) -> (B, H, L, D)
            return t.reshape(batch, seq_len, heads, -1).transpose(0, 2, 1, 3)

        q, k, v = (split_heads(t) for t in mx.split(self.qkv(x), 3, axis=-1))
        q, k = self.norm(q, k)
        return self.proj(_attention(q, k, v, pe))
@dataclass
class ModulationOut:
    """Shift, scale and gate arrays produced by ``Modulation``.

    Consumers in this module apply them as ``(1 + scale) * norm(x) + shift`` and
    ``x + gate * f(...)``. Field order matters: instances are constructed
    positionally from consecutive split chunks.
    """

    shift: mx.array
    scale: mx.array
    gate: mx.array
class Modulation(nn.Module):
    """Turn a conditioning vector into shift/scale/gate modulation parameters.

    With ``double=True`` the projection is split into 6 chunks (two
    ``ModulationOut`` triples); otherwise into 3 chunks, and the second returned
    value is ``None``. A singleton axis is inserted at position 1 before
    splitting, so each chunk can broadcast over a sequence axis.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
        projected = self.lin(nn.silu(x))[:, None, :]
        chunks = mx.split(projected, self.multiplier, axis=-1)
        first = ModulationOut(*chunks[:3])
        if not self.is_double:
            return first, None
        return first, ModulationOut(*chunks[3:])
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams joined in attention.

    Each stream has its own modulation, pre-norms, attention projections and MLP.
    Text and image queries/keys/values are concatenated (text first) along the
    sequence axis and attended jointly. The output is split back per stream and
    added residually, gated by the corresponding modulation.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        def make_mlp():
            return nn.Sequential(
                nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
                nn.GELU(approx="tanh"),
                nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
            )

        # Construction order is kept as-is: image stream first, then text stream.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )
        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_mlp = make_mlp()

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )
        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_mlp = make_mlp()

    def __call__(
        self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
    ) -> Tuple[mx.array, mx.array]:
        B, L, _ = img.shape
        _, S, _ = txt.shape
        H = self.num_heads

        def to_heads(t, n):
            # (B, n, H*D) -> (B, H, n, D); batch size is taken from img.
            return t.reshape(B, n, H, -1).transpose(0, 2, 1, 3)

        def qkv_for(attn, tokens, n):
            q, k, v = mx.split(attn.qkv(tokens), 3, axis=-1)
            q, k = attn.norm(to_heads(q, n), to_heads(k, n))
            return q, k, to_heads(v, n)

        def modulate(normed, mod):
            return (1 + mod.scale) * normed + mod.shift

        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        img_q, img_k, img_v = qkv_for(
            self.img_attn, modulate(self.img_norm1(img), img_mod1), L
        )
        txt_q, txt_k, txt_v = qkv_for(
            self.txt_attn, modulate(self.txt_norm1(txt), txt_mod1), S
        )

        # Joint attention over the concatenated sequence [text tokens, image tokens].
        q, k, v = (
            mx.concatenate([t_part, i_part], axis=2)
            for t_part, i_part in ((txt_q, img_q), (txt_k, img_k), (txt_v, img_v))
        )
        joint = _attention(q, k, v, pe)
        txt_attn, img_attn = mx.split(joint, [S], axis=1)

        # Image stream residual updates (attention, then MLP).
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            modulate(self.img_norm2(img), img_mod2)
        )

        # Text stream residual updates (attention, then MLP).
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            modulate(self.txt_norm2(txt), txt_mod2)
        )
        return img, txt
class SingleStreamBlock(nn.Module):
    """Parallel attention + MLP block over a single token stream.

    ``linear1`` produces q, k, v and the MLP input in one projection. ``linear2``
    maps the concatenated attention output and activated MLP features back to
    ``hidden_size``. The result is added to the input, gated by the modulation.

    Args:
        hidden_size: Token feature width.
        num_heads: Number of attention heads; the per-head width is
            ``hidden_size // num_heads``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        qk_scale: Attention logit scale. ``None`` (default) uses
            ``head_dim ** -0.5``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: Optional[float] = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale if qk_scale is not None else head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.norm = QKNorm(head_dim)
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approx="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
        B, L, _ = x.shape
        H = self.num_heads
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        q, k, v, mlp = mx.split(
            self.linear1(x_mod),
            [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
            axis=-1,
        )
        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        q, k = self.norm(q, k)

        # Attention is computed here instead of through ``_attention``, because
        # that helper always scales by D ** -0.5 and would silently ignore
        # ``qk_scale``. With the default scale the result is identical.
        q = _apply_rope(q, pe)
        k = _apply_rope(k, pe)
        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
        y = y.transpose(0, 2, 1, 3).reshape(B, L, -1)

        # compute activation in mlp stream, cat again and run second linear layer
        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=-1))
        return x + mod.gate * y
class LastLayer(nn.Module):
    """Final modulated LayerNorm followed by a projection to per-patch outputs.

    ``vec`` is mapped to a (shift, scale) pair, split along axis 1. The pair is
    broadcast over axis 1 of ``x`` before projecting to
    ``patch_size * patch_size * out_channels`` features.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def __call__(self, x: mx.array, vec: mx.array):
        modulation = self.adaLN_modulation(vec)
        shift, scale = mx.split(modulation, 2, axis=1)
        # Insert a singleton axis so the modulation broadcasts over x's axis 1.
        shift = shift[:, None, :]
        scale = scale[:, None, :]
        return self.linear((1 + scale) * self.norm_final(x) + shift)