mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Flux implementation in examples
This commit is contained in:
0
flux/flux/__init__.py
Normal file
0
flux/flux/__init__.py
Normal file
360
flux/flux/autoencoder.py
Normal file
360
flux/flux/autoencoder.py
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.nn.layers.upsample import upsample_nearest
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoEncoderParams:
|
||||||
|
resolution: int
|
||||||
|
in_channels: int
|
||||||
|
ch: int
|
||||||
|
out_ch: int
|
||||||
|
ch_mult: List[int]
|
||||||
|
num_res_blocks: int
|
||||||
|
z_channels: int
|
||||||
|
scale_factor: float
|
||||||
|
shift_factor: float
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = nn.GroupNorm(
|
||||||
|
num_groups=32,
|
||||||
|
dims=in_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True,
|
||||||
|
pytorch_compatible=True,
|
||||||
|
)
|
||||||
|
self.q = nn.Linear(in_channels, in_channels)
|
||||||
|
self.k = nn.Linear(in_channels, in_channels)
|
||||||
|
self.v = nn.Linear(in_channels, in_channels)
|
||||||
|
self.proj_out = nn.Linear(in_channels, in_channels)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
|
||||||
|
y = x.reshape(B, 1, -1, C)
|
||||||
|
y = self.norm(y)
|
||||||
|
q = self.q(y)
|
||||||
|
k = self.k(y)
|
||||||
|
v = self.v(y)
|
||||||
|
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
|
||||||
|
y = self.proj_out(y)
|
||||||
|
|
||||||
|
return x + y.reshape(B, H, W, C)
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, out_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(
|
||||||
|
num_groups=32,
|
||||||
|
dims=in_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True,
|
||||||
|
pytorch_compatible=True,
|
||||||
|
)
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
self.norm2 = nn.GroupNorm(
|
||||||
|
num_groups=32,
|
||||||
|
dims=out_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True,
|
||||||
|
pytorch_compatible=True,
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
self.nin_shortcut = nn.Linear(in_channels, out_channels)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
x = upsample_nearest(x, (2, 2))
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resolution: int,
|
||||||
|
in_channels: int,
|
||||||
|
ch: int,
|
||||||
|
ch_mult: list[int],
|
||||||
|
num_res_blocks: int,
|
||||||
|
z_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = []
|
||||||
|
block_in = self.ch
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = []
|
||||||
|
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||||
|
block_in = block_out
|
||||||
|
down = {}
|
||||||
|
down["block"] = block
|
||||||
|
down["attn"] = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down["downsample"] = Downsample(block_in)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = {}
|
||||||
|
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
self.mid["attn_1"] = AttnBlock(block_in)
|
||||||
|
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = nn.GroupNorm(
|
||||||
|
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(
|
||||||
|
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level]["block"][i_block](hs[-1])
|
||||||
|
|
||||||
|
# TODO: Remove the attn
|
||||||
|
if len(self.down[i_level]["attn"]) > 0:
|
||||||
|
h = self.down[i_level]["attn"][i_block](h)
|
||||||
|
|
||||||
|
hs.append(h)
|
||||||
|
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
hs.append(self.down[i_level]["downsample"](hs[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid["block_1"](h)
|
||||||
|
h = self.mid["attn_1"](h)
|
||||||
|
h = self.mid["block_2"](h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch: int,
|
||||||
|
out_ch: int,
|
||||||
|
ch_mult: list[int],
|
||||||
|
num_res_blocks: int,
|
||||||
|
in_channels: int,
|
||||||
|
resolution: int,
|
||||||
|
z_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = {}
|
||||||
|
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
self.mid["attn_1"] = AttnBlock(block_in)
|
||||||
|
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = []
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = []
|
||||||
|
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
||||||
|
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks + 1):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||||
|
block_in = block_out
|
||||||
|
up = {}
|
||||||
|
up["block"] = block
|
||||||
|
up["attn"] = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up["upsample"] = Upsample(block_in)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = nn.GroupNorm(
|
||||||
|
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def __call__(self, z: mx.array):
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid["block_1"](h)
|
||||||
|
h = self.mid["attn_1"](h)
|
||||||
|
h = self.mid["block_2"](h)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level]["block"][i_block](h)
|
||||||
|
|
||||||
|
# TODO: Remove the attn
|
||||||
|
if len(self.up[i_level]["attn"]) > 0:
|
||||||
|
h = self.up[i_level]["attn"][i_block](h)
|
||||||
|
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level]["upsample"](h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussian(nn.Module):
|
||||||
|
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.sample = sample
|
||||||
|
self.chunk_dim = chunk_dim
|
||||||
|
|
||||||
|
def __call__(self, z: mx.array):
|
||||||
|
mean, logvar = mx.split(z, 2, axis=self.chunk_dim)
|
||||||
|
if self.sample:
|
||||||
|
std = mx.exp(0.5 * logvar)
|
||||||
|
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
|
||||||
|
return mean + std * eps
|
||||||
|
else:
|
||||||
|
return mean
|
||||||
|
|
||||||
|
|
||||||
|
class AutoEncoder(nn.Module):
|
||||||
|
def __init__(self, params: AutoEncoderParams):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = Encoder(
|
||||||
|
resolution=params.resolution,
|
||||||
|
in_channels=params.in_channels,
|
||||||
|
ch=params.ch,
|
||||||
|
ch_mult=params.ch_mult,
|
||||||
|
num_res_blocks=params.num_res_blocks,
|
||||||
|
z_channels=params.z_channels,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
resolution=params.resolution,
|
||||||
|
in_channels=params.in_channels,
|
||||||
|
ch=params.ch,
|
||||||
|
out_ch=params.out_ch,
|
||||||
|
ch_mult=params.ch_mult,
|
||||||
|
num_res_blocks=params.num_res_blocks,
|
||||||
|
z_channels=params.z_channels,
|
||||||
|
)
|
||||||
|
self.reg = DiagonalGaussian()
|
||||||
|
|
||||||
|
self.scale_factor = params.scale_factor
|
||||||
|
self.shift_factor = params.shift_factor
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
new_weights = {}
|
||||||
|
for k, w in weights.items():
|
||||||
|
if w.ndim == 4:
|
||||||
|
w = w.transpose(0, 2, 3, 1)
|
||||||
|
w = w.reshape(-1).reshape(w.shape)
|
||||||
|
if w.shape[1:3] == (1, 1):
|
||||||
|
w = w.squeeze((1, 2))
|
||||||
|
new_weights[k] = w
|
||||||
|
return new_weights
|
||||||
|
|
||||||
|
def encode(self, x: mx.array):
|
||||||
|
z = self.reg(self.encoder(x))
|
||||||
|
z = self.scale_factor * (z - self.shift_factor)
|
||||||
|
return z
|
||||||
|
|
||||||
|
def decode(self, z: mx.array):
|
||||||
|
z = z / self.scale_factor + self.shift_factor
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
return self.decode(self.encode(x))
|
246
flux/flux/layers.py
Normal file
246
flux/flux/layers.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class MLPEmbedder(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||||
|
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return self.out_layer(nn.silu(self.in_layer(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class QKNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.query_norm = nn.RMSNorm(dim)
|
||||||
|
self.key_norm = nn.RMSNorm(dim)
|
||||||
|
|
||||||
|
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
|
||||||
|
return self.query_norm(q), self.key_norm(k)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.rope = nn.RoPE(head_dim, True, base=10000)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
|
||||||
|
H = self.num_heads
|
||||||
|
B, L, _ = x.shape
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||||
|
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
q, k = self.norm(q, k)
|
||||||
|
q = self.rope(q)
|
||||||
|
k = self.rope(k)
|
||||||
|
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
|
||||||
|
x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModulationOut:
|
||||||
|
shift: mx.array
|
||||||
|
scale: mx.array
|
||||||
|
gate: mx.array
|
||||||
|
|
||||||
|
|
||||||
|
class Modulation(nn.Module):
|
||||||
|
def __init__(self, dim: int, double: bool):
|
||||||
|
super().__init__()
|
||||||
|
self.is_double = double
|
||||||
|
self.multiplier = 6 if double else 3
|
||||||
|
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
||||||
|
x = self.lin(nn.silu(x))
|
||||||
|
xs = mx.split(x, self.multiplier, axis=-1)
|
||||||
|
|
||||||
|
mod1 = ModulationOut(*xs[:3])
|
||||||
|
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
||||||
|
|
||||||
|
return mod1, mod2
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.img_mod = Modulation(hidden_size, double=True)
|
||||||
|
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||||
|
self.img_attn = SelfAttention(
|
||||||
|
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||||
|
self.img_mlp = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_mod = Modulation(hidden_size, double=True)
|
||||||
|
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.txt_attn = SelfAttention(
|
||||||
|
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.txt_mlp = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
B, L, _ = img.shape
|
||||||
|
_, S, _ = txt.shape
|
||||||
|
H = self.num_heads
|
||||||
|
|
||||||
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
|
# prepare image for attention
|
||||||
|
img_modulated = self.img_norm1(img)
|
||||||
|
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||||
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
|
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
|
||||||
|
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
img_q, img_k = self.norm(img_q, img_k)
|
||||||
|
|
||||||
|
# prepare txt for attention
|
||||||
|
txt_modulated = self.txt_norm1(txt)
|
||||||
|
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||||
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
|
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
|
||||||
|
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
txt_q, txt_k = self.norm(txt_q, txt_k)
|
||||||
|
|
||||||
|
# run actual attention
|
||||||
|
q = mx.concatenate([txt_q, img_q], axis=2)
|
||||||
|
k = mx.concatenate([txt_k, img_k], axis=2)
|
||||||
|
v = mx.concatenate([txt_v, img_v], axis=2)
|
||||||
|
|
||||||
|
q = self.img_attn.rope(q)
|
||||||
|
k = self.img_attn.rope(k)
|
||||||
|
attn = mx.fast.scaled_dot_product_attention(
|
||||||
|
q, k, v, scale=q.shape[-1] ** (-0.5)
|
||||||
|
)
|
||||||
|
attn = attn.transpose(0, 2, 1, 3).reshape(B, L + S, -1)
|
||||||
|
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||||
|
|
||||||
|
# calculate the img bloks
|
||||||
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
|
img = img + img_mod2.gate * self.img_mlp(
|
||||||
|
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate the txt bloks
|
||||||
|
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
|
txt = txt + txt_mod2.gate * self.txt_mlp(
|
||||||
|
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
||||||
|
)
|
||||||
|
|
||||||
|
return img, txt
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStreamBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
qk_scale: Optional[float] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = hidden_size // num_heads
|
||||||
|
self.scale = qk_scale or head_dim**-0.5
|
||||||
|
|
||||||
|
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
# qkv and mlp_in
|
||||||
|
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||||
|
# proj and mlp_out
|
||||||
|
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||||
|
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
self.mlp_act = nn.GELU(approximate="tanh")
|
||||||
|
self.modulation = Modulation(hidden_size, double=False)
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(head_dim, True, base=10000)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
|
||||||
|
B, L, _ = x.shape
|
||||||
|
H = self.num_heads
|
||||||
|
|
||||||
|
mod, _ = self.modulation(vec)
|
||||||
|
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||||
|
|
||||||
|
q, k, v, mlp = mx.split(
|
||||||
|
self.linear1(x_mod),
|
||||||
|
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||||
|
q, k = self.norm(q, k)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
q = self.rope(q)
|
||||||
|
k = self.rope(k)
|
||||||
|
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
|
||||||
|
y = y.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
|
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
|
||||||
|
return x + mod.gate * y
|
||||||
|
|
||||||
|
|
||||||
|
class LastLayer(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||||
|
self.linear = nn.Linear(
|
||||||
|
hidden_size, patch_size * patch_size * out_channels, bias=True
|
||||||
|
)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, vec: mx.array):
|
||||||
|
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
|
||||||
|
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
54
flux/flux/model.py
Normal file
54
flux/flux/model.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .layers import (
|
||||||
|
DoubleStreamBlock,
|
||||||
|
EmbedND,
|
||||||
|
LastLayer,
|
||||||
|
MLPEmbedder,
|
||||||
|
SingleStreamBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FluxParams:
|
||||||
|
in_channels: int
|
||||||
|
vec_in_dim: int
|
||||||
|
context_in_dim: int
|
||||||
|
hidden_size: int
|
||||||
|
mlp_ratio: float
|
||||||
|
num_heads: int
|
||||||
|
depth: int
|
||||||
|
depth_single_blocks: int
|
||||||
|
axes_dim: list[int]
|
||||||
|
theta: int
|
||||||
|
qkv_bias: bool
|
||||||
|
guidance_embed: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Flux(nn.Module):
|
||||||
|
def __init__(self, params: FluxParams):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.params = params
|
||||||
|
self.in_channels = params.in_channels
|
||||||
|
self.out_channels = self.in_channels
|
||||||
|
if params.hidden_size % params.num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
img: mx.array,
|
||||||
|
img_ids: mx.array,
|
||||||
|
txt: mx.array,
|
||||||
|
txt_ids: mx.array,
|
||||||
|
timesteps: mx.array,
|
||||||
|
y: mx.array,
|
||||||
|
guidance: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
pass
|
138
flux/flux/utils.py
Normal file
138
flux/flux/utils.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from .autoencoder import AutoEncoder, AutoEncoderParams
|
||||||
|
from .model import Flux, FluxParams
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelSpec:
|
||||||
|
params: FluxParams
|
||||||
|
ae_params: AutoEncoderParams
|
||||||
|
ckpt_path: Optional[str]
|
||||||
|
ae_path: Optional[str]
|
||||||
|
repo_id: Optional[str]
|
||||||
|
repo_flow: Optional[str]
|
||||||
|
repo_ae: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
configs = {
|
||||||
|
"flux-dev": ModelSpec(
|
||||||
|
repo_id="black-forest-labs/FLUX.1-dev",
|
||||||
|
repo_flow="flux1-dev.safetensors",
|
||||||
|
repo_ae="ae.safetensors",
|
||||||
|
ckpt_path=os.getenv("FLUX_DEV"),
|
||||||
|
params=FluxParams(
|
||||||
|
in_channels=64,
|
||||||
|
vec_in_dim=768,
|
||||||
|
context_in_dim=4096,
|
||||||
|
hidden_size=3072,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
num_heads=24,
|
||||||
|
depth=19,
|
||||||
|
depth_single_blocks=38,
|
||||||
|
axes_dim=[16, 56, 56],
|
||||||
|
theta=10_000,
|
||||||
|
qkv_bias=True,
|
||||||
|
guidance_embed=True,
|
||||||
|
),
|
||||||
|
ae_path=os.getenv("AE"),
|
||||||
|
ae_params=AutoEncoderParams(
|
||||||
|
resolution=256,
|
||||||
|
in_channels=3,
|
||||||
|
ch=128,
|
||||||
|
out_ch=3,
|
||||||
|
ch_mult=[1, 2, 4, 4],
|
||||||
|
num_res_blocks=2,
|
||||||
|
z_channels=16,
|
||||||
|
scale_factor=0.3611,
|
||||||
|
shift_factor=0.1159,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"flux-schnell": ModelSpec(
|
||||||
|
repo_id="black-forest-labs/FLUX.1-schnell",
|
||||||
|
repo_flow="flux1-schnell.safetensors",
|
||||||
|
repo_ae="ae.safetensors",
|
||||||
|
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
||||||
|
params=FluxParams(
|
||||||
|
in_channels=64,
|
||||||
|
vec_in_dim=768,
|
||||||
|
context_in_dim=4096,
|
||||||
|
hidden_size=3072,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
num_heads=24,
|
||||||
|
depth=19,
|
||||||
|
depth_single_blocks=38,
|
||||||
|
axes_dim=[16, 56, 56],
|
||||||
|
theta=10_000,
|
||||||
|
qkv_bias=True,
|
||||||
|
guidance_embed=False,
|
||||||
|
),
|
||||||
|
ae_path=os.getenv("AE"),
|
||||||
|
ae_params=AutoEncoderParams(
|
||||||
|
resolution=256,
|
||||||
|
in_channels=3,
|
||||||
|
ch=128,
|
||||||
|
out_ch=3,
|
||||||
|
ch_mult=[1, 2, 4, 4],
|
||||||
|
num_res_blocks=2,
|
||||||
|
z_channels=16,
|
||||||
|
scale_factor=0.3611,
|
||||||
|
shift_factor=0.1159,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_flow_model(name: str, hf_download: bool = True):
|
||||||
|
# Get the safetensors file to load
|
||||||
|
ckpt_path = configs[name].ckpt_path
|
||||||
|
|
||||||
|
# Download if needed
|
||||||
|
if (
|
||||||
|
ckpt_path is None
|
||||||
|
and configs[name].repo_id is not None
|
||||||
|
and configs[name].repo_flow is not None
|
||||||
|
and hf_download
|
||||||
|
):
|
||||||
|
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
||||||
|
|
||||||
|
# Make the model
|
||||||
|
model = Flux(configs[name].params)
|
||||||
|
|
||||||
|
# Load the checkpoint if needed
|
||||||
|
if ckpt_path is not None:
|
||||||
|
weights = mx.load(ckpt_path)
|
||||||
|
weights = model.sanitize(weights)
|
||||||
|
model.load_weights(list(weights.items()))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_ae(name: str, hf_download: bool = True):
|
||||||
|
# Get the safetensors file to load
|
||||||
|
ckpt_path = configs[name].ae_path
|
||||||
|
|
||||||
|
# Download if needed
|
||||||
|
if (
|
||||||
|
ckpt_path is None
|
||||||
|
and configs[name].repo_id is not None
|
||||||
|
and configs[name].repo_ae is not None
|
||||||
|
and hf_download
|
||||||
|
):
|
||||||
|
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
|
||||||
|
|
||||||
|
# Make the autoencoder
|
||||||
|
ae = AutoEncoder(configs[name].ae_params)
|
||||||
|
|
||||||
|
# Load the checkpoint if needed
|
||||||
|
if ckpt_path is not None:
|
||||||
|
weights = mx.load(ckpt_path)
|
||||||
|
weights = ae.sanitize(weights)
|
||||||
|
ae.load_weights(list(weights.items()))
|
||||||
|
|
||||||
|
return ae
|
Reference in New Issue
Block a user