Flux implementation in examples

This commit is contained in:
Angelos Katharopoulos 2024-09-25 00:58:30 -07:00
parent f61f4b5cf1
commit ed17f815f5
5 changed files with 798 additions and 0 deletions

0
flux/flux/__init__.py Normal file
View File

360
flux/flux/autoencoder.py Normal file
View File

@ -0,0 +1,360 @@
from dataclasses import dataclass
from typing import List
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.upsample import upsample_nearest
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: List[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.q = nn.Linear(in_channels, in_channels)
self.k = nn.Linear(in_channels, in_channels)
self.v = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)
def __call__(self, x: mx.array) -> mx.array:
B, H, W, C = x.shape
y = x.reshape(B, 1, -1, C)
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
y = self.proj_out(y)
return x + y.reshape(B, H, W, C)
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32,
dims=out_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x):
h = x
h = self.norm1(h)
h = nn.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = nn.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array):
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
x = upsample_nearest(x, (2, 2))
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = []
block_in = self.ch
for i_level in range(self.num_resolutions):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = {}
down["block"] = block
down["attn"] = attn
if i_level != self.num_resolutions - 1:
down["downsample"] = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level]["block"][i_block](hs[-1])
# TODO: Remove the attn
if len(self.down[i_level]["attn"]) > 0:
h = self.down[i_level]["attn"][i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level]["downsample"](hs[-1]))
# middle
h = hs[-1]
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = []
for i_level in reversed(range(self.num_resolutions)):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = {}
up["block"] = block
up["attn"] = attn
if i_level != 0:
up["upsample"] = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def __call__(self, z: mx.array):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level]["block"][i_block](h)
# TODO: Remove the attn
if len(self.up[i_level]["attn"]) > 0:
h = self.up[i_level]["attn"][i_block](h)
if i_level != 0:
h = self.up[i_level]["upsample"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def __call__(self, z: mx.array):
mean, logvar = mx.split(z, 2, axis=self.chunk_dim)
if self.sample:
std = mx.exp(0.5 * logvar)
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
return mean + std * eps
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if w.ndim == 4:
w = w.transpose(0, 2, 3, 1)
w = w.reshape(-1).reshape(w.shape)
if w.shape[1:3] == (1, 1):
w = w.squeeze((1, 2))
new_weights[k] = w
return new_weights
def encode(self, x: mx.array):
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: mx.array):
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def __call__(self, x: mx.array):
return self.decode(self.encode(x))

246
flux/flux/layers.py Normal file
View File

@ -0,0 +1,246 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.out_layer(nn.silu(self.in_layer(x)))
class QKNorm(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = nn.RMSNorm(dim)
self.key_norm = nn.RMSNorm(dim)
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
return self.query_norm(q), self.key_norm(k)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
self.rope = nn.RoPE(head_dim, True, base=10000)
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
H = self.num_heads
B, L, _ = x.shape
qkv = self.qkv(x)
q, k, v = mx.split(qkv, 3, axis=-1)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
q = self.rope(q)
k = self.rope(k)
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: mx.array
scale: mx.array
gate: mx.array
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
x = self.lin(nn.silu(x))
xs = mx.split(x, self.multiplier, axis=-1)
mod1 = ModulationOut(*xs[:3])
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
return mod1, mod2
class DoubleStreamBlock(nn.Module):
def __init__(
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
B, L, _ = img.shape
_, S, _ = txt.shape
H = self.num_heads
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_q, img_k = self.norm(img_q, img_k)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_q, txt_k = self.norm(txt_q, txt_k)
# run actual attention
q = mx.concatenate([txt_q, img_q], axis=2)
k = mx.concatenate([txt_k, img_k], axis=2)
v = mx.concatenate([txt_v, img_v], axis=2)
q = self.img_attn.rope(q)
k = self.img_attn.rope(k)
attn = mx.fast.scaled_dot_product_attention(
q, k, v, scale=q.shape[-1] ** (-0.5)
)
attn = attn.transpose(0, 2, 1, 3).reshape(B, L + S, -1)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
self.rope = nn.RoPE(head_dim, True, base=10000)
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
B, L, _ = x.shape
H = self.num_heads
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
q, k, v, mlp = mx.split(
self.linear1(x_mod),
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
axis=-1,
)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
# compute attention
q = self.rope(q)
k = self.rope(k)
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
y = y.transpose(0, 2, 1, 3).reshape(B, L, -1)
# compute activation in mlp stream, cat again and run second linear layer
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
return x + mod.gate * y
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def __call__(self, x: mx.array, vec: mx.array):
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

54
flux/flux/model.py Normal file
View File

@ -0,0 +1,54 @@
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
def __call__(
self,
img: mx.array,
img_ids: mx.array,
txt: mx.array,
txt_ids: mx.array,
timesteps: mx.array,
y: mx.array,
guidance: Optional[mx.array] = None,
) -> mx.array:
pass

138
flux/flux/utils.py Normal file
View File

@ -0,0 +1,138 @@
import os
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
from huggingface_hub import hf_hub_download
from .autoencoder import AutoEncoder, AutoEncoderParams
from .model import Flux, FluxParams
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: Optional[str]
ae_path: Optional[str]
repo_id: Optional[str]
repo_flow: Optional[str]
repo_ae: Optional[str]
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
def load_flow_model(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ckpt_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
# Make the model
model = Flux(configs[name].params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()))
return model
def load_ae(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ae_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
# Make the autoencoder
ae = AutoEncoder(configs[name].ae_params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = ae.sanitize(weights)
ae.load_weights(list(weights.items()))
return ae