# Copyright © 2024 Apple Inc.

import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
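

# Rotary position embedding: for every position and frequency this builds a 2x2
# rotation matrix [[cos, -sin], [sin, cos]]; output shape (..., dim // 2, 2, 2).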
def _rope(pos: mx.array, dim: int, theta: float):
    scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
    omega = 1.0 / (theta**scale)
    x = pos[..., None] * omega
    cosx = mx.cos(x)
    sinx = mx.sin(x)
    pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
    pe = pe.reshape(*pe.shape[:-1], 2, 2)

    return pe
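

# Fused elementwise a*b + c*d, compiled without shape specialization
# (shapeless=True) so one compiled graph serves all input shapes.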
@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
    return a * b + c * d
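

# Rotate consecutive channel pairs of x by the per-position rotation matrices in pe.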
def _apply_rope(x, pe):
    s = x.shape
    x = x.reshape(*s[:-1], -1, 1, 2)
    x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
    return x.reshape(s)
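

# Multi-head attention with RoPE applied to queries and keys; the heads are
# merged back into the last dimension on the way out.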
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
    B, H, L, D = q.shape

    q = _apply_rope(q, pe)
    k = _apply_rope(k, pe)
    x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))

    return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
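

# Sinusoidal embedding of the (scaled) timesteps t into a dim-dimensional vector.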
def timestep_embedding(
    t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
    half = dim // 2
    freqs = mx.arange(0, half, dtype=mx.float32) / half
    freqs = freqs * (-math.log(max_period))
    freqs = mx.exp(freqs)

    x = (time_factor * t)[:, None] * freqs[None]
    x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)

    return x.astype(t.dtype)
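

# Multi-axis rotary embedding: each axis of `ids` gets its own RoPE of width
# axes_dim[i]; the per-axis tables are concatenated along the frequency axis.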
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()

        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def __call__(self, ids: mx.array):
        n_axes = ids.shape[-1]
        pe = mx.concatenate(
            [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            axis=-3,
        )

        return pe[:, None]
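

# Two-layer MLP with a SiLU nonlinearity.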
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        return self.out_layer(nn.silu(self.in_layer(x)))
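

# RMS-normalizes queries and keys before attention.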
class QKNorm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = nn.RMSNorm(dim)
        self.key_norm = nn.RMSNorm(dim)

    def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
        return self.query_norm(q), self.key_norm(k)
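

# Multi-head self-attention with QK normalization and rotary position embeddings.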
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
        H = self.num_heads
        B, L, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = mx.split(qkv, 3, axis=-1)
        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        q, k = self.norm(q, k)
        x = _attention(q, k, v, pe)
        x = self.proj(x)
        return x
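

# Shift/scale/gate triple produced by adaptive layer-norm (adaLN) modulation.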
@dataclass
class ModulationOut:
    shift: mx.array
    scale: mx.array
    gate: mx.array
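

# Projects a conditioning vector to one set of shift/scale/gate parameters
# (two sets when double=True).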
class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
        x = self.lin(nn.silu(x))
        xs = mx.split(x[:, None, :], self.multiplier, axis=-1)

        mod1 = ModulationOut(*xs[:3])
        mod2 = ModulationOut(*xs[3:]) if self.is_double else None

        return mod1, mod2
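

# Transformer block with separate image and text streams: each stream is
# modulated and normalized independently, but attention runs over the
# concatenated text + image sequence.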
class DoubleStreamBlock(nn.Module):
    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approx="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approx="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def __call__(
        self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
    ) -> Tuple[mx.array, mx.array]:
        B, L, _ = img.shape
        _, S, _ = txt.shape
        H = self.num_heads

        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
        img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_q, img_k = self.img_attn.norm(img_q, img_k)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
        txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

        # run actual attention
        q = mx.concatenate([txt_q, img_q], axis=2)
        k = mx.concatenate([txt_k, img_k], axis=2)
        v = mx.concatenate([txt_v, img_v], axis=2)

        attn = _attention(q, k, v, pe)
        txt_attn, img_attn = mx.split(attn, [S], axis=1)

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )

        return img, txt
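

# Transformer block for a single concatenated stream: one fused linear layer
# produces QKV and the MLP input, and a second linear layer merges the
# attention and MLP outputs back to hidden_size.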
class SingleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: Optional[float] = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approx="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
        B, L, _ = x.shape
        H = self.num_heads

        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift

        q, k, v, mlp = mx.split(
            self.linear1(x_mod),
            [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
            axis=-1,
        )
        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        q, k = self.norm(q, k)

        # compute attention
        y = _attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
        return x + mod.gate * y
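

# Final adaLN modulation followed by a linear projection to
# patch_size**2 * out_channels values per token.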
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def __call__(self, x: mx.array, vec: mx.array):
        shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
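

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameters
# and sequence lengths below are illustrative assumptions chosen so that
# sum(axes_dim) equals the per-head dimension; they are not the configuration
# of any released model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, L_txt, L_img = 1, 8, 16
    hidden_size, num_heads = 64, 4  # head_dim = 16

    # Positional ids with three axes; axes_dim must sum to head_dim.
    pos_embedder = EmbedND(
        dim=hidden_size // num_heads, theta=10_000, axes_dim=[4, 6, 6]
    )
    ids = mx.zeros((B, L_txt + L_img, 3))
    pe = pos_embedder(ids)

    img = mx.random.normal((B, L_img, hidden_size))
    txt = mx.random.normal((B, L_txt, hidden_size))
    vec = mx.random.normal((B, hidden_size))

    # Double-stream block keeps image and text separate but attends jointly.
    double = DoubleStreamBlock(hidden_size, num_heads, mlp_ratio=4.0)
    img, txt = double(img, txt, vec, pe)

    # Single-stream block runs on the concatenated text + image sequence.
    single = SingleStreamBlock(hidden_size, num_heads)
    x = single(mx.concatenate([txt, img], axis=1), vec, pe)

    # Final projection of the image tokens back to patch values.
    final = LastLayer(hidden_size, patch_size=2, out_channels=16)
    out = final(x[:, L_txt:], vec)
    print(img.shape, txt.shape, out.shape)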