From ed17f815f50ad623949ca2a99679534e4b74fbf4 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Sep 2024 00:58:30 -0700 Subject: [PATCH] Flux implementation in examples --- flux/flux/__init__.py | 0 flux/flux/autoencoder.py | 360 +++++++++++++++++++++++++++++++++++++++ flux/flux/layers.py | 246 ++++++++++++++++++++++++++ flux/flux/model.py | 54 ++++++ flux/flux/utils.py | 138 +++++++++++++++ 5 files changed, 798 insertions(+) create mode 100644 flux/flux/__init__.py create mode 100644 flux/flux/autoencoder.py create mode 100644 flux/flux/layers.py create mode 100644 flux/flux/model.py create mode 100644 flux/flux/utils.py diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flux/flux/autoencoder.py b/flux/flux/autoencoder.py new file mode 100644 index 00000000..9d470cb9 --- /dev/null +++ b/flux/flux/autoencoder.py @@ -0,0 +1,360 @@ +from dataclasses import dataclass +from typing import List + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.upsample import upsample_nearest + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: List[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.q = nn.Linear(in_channels, in_channels) + self.k = nn.Linear(in_channels, in_channels) + self.v = nn.Linear(in_channels, in_channels) + self.proj_out = nn.Linear(in_channels, in_channels) + + def __call__(self, x: mx.array) -> mx.array: + B, H, W, C = x.shape + + y = x.reshape(B, 1, -1, C) + y = self.norm(y) + q = self.q(y) + k = self.k(y) + v = self.v(y) + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5)) + y = self.proj_out(y) + + return x + y.reshape(B, H, W, C) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, + dims=out_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Linear(in_channels, out_channels) + + def __call__(self, x): + h = x + h = self.norm1(h) + h = nn.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nn.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def __call__(self, x: mx.array): + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, 
padding=1 + ) + + def __call__(self, x: mx.array): + x = upsample_nearest(x, (2, 2)) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = [] + block_in = self.ch + for i_level in range(self.num_resolutions): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = {} + down["block"] = block + down["attn"] = attn + if i_level != self.num_resolutions - 1: + down["downsample"] = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level]["block"][i_block](hs[-1]) + + # TODO: Remove the attn + if len(self.down[i_level]["attn"]) > 0: + h = self.down[i_level]["attn"][i_block](h) + + hs.append(h) + + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level]["downsample"](hs[-1])) + + # middle + h = hs[-1] + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = [] + for i_level in reversed(range(self.num_resolutions)): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks 
+ 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            up = {}
+            up["block"] = block
+            up["attn"] = attn
+            if i_level != 0:
+                up["upsample"] = Upsample(block_in)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def __call__(self, z: mx.array):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level]["block"][i_block](h)
+
+                # TODO: Remove the attn
+                if len(self.up[i_level]["attn"]) > 0:
+                    h = self.up[i_level]["attn"][i_block](h)
+
+            if i_level != 0:
+                h = self.up[i_level]["upsample"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class DiagonalGaussian(nn.Module):
+    def __init__(self, sample: bool = True, chunk_dim: int = -1):
+        super().__init__()
+        self.sample = sample
+        self.chunk_dim = chunk_dim
+
+    def __call__(self, z: mx.array):
+        mean, logvar = mx.split(z, 2, axis=self.chunk_dim)
+        if self.sample:
+            std = mx.exp(0.5 * logvar)
+            eps = mx.random.normal(shape=mean.shape, dtype=z.dtype)
+            return mean + std * eps
+        else:
+            return mean
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, params: AutoEncoderParams):
+        super().__init__()
+        self.encoder = Encoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if w.ndim == 4:
+                w = w.transpose(0, 2, 3, 1)
+                w = w.reshape(-1).reshape(w.shape)
+                if w.shape[1:3] == (1, 1):
+                    w = w.squeeze((1, 2))
+            new_weights[k] = w
+        return new_weights
+
+    def encode(self, x: mx.array):
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: mx.array):
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def __call__(self, x: mx.array):
+        return self.decode(self.encode(x))
diff --git a/flux/flux/layers.py b/flux/flux/layers.py
new file mode 100644
index 00000000..39b8ee0a
--- /dev/null
+++ b/flux/flux/layers.py
@@ -0,0 +1,246 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.out_layer(nn.silu(self.in_layer(x)))
+
+
+class QKNorm(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.query_norm = nn.RMSNorm(dim)
+        self.key_norm = nn.RMSNorm(dim)
+
+    def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
+        return self.query_norm(q), self.key_norm(k)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = nn.Linear(dim, dim)
+        self.rope = nn.RoPE(head_dim, True, base=10000)
+
+    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
+        H = self.num_heads
+        B, L, _ = x.shape
+        qkv = self.qkv(x)
+        q, k, v = mx.split(qkv, 3, axis=-1)
+        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        q, k = self.norm(q, k)
+        q = self.rope(q)
+        k = self.rope(k)
+        x = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
+        x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        x = self.proj(x)
+        return x
+
+
+@dataclass
+class ModulationOut:
+    shift: mx.array
+    scale: mx.array
+    gate: mx.array
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim: int, double: bool):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+    def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
+        x = self.lin(nn.silu(x))
+        xs = mx.split(x, self.multiplier, axis=-1)
+
+        mod1 = ModulationOut(*xs[:3])
+        mod2 = ModulationOut(*xs[3:]) if self.is_double else None
+
+        return mod1, mod2
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(
+        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
+    ):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.img_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.img_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.txt_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.txt_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+    def __call__(
+        self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        B, L, _ = img.shape
+        _, S, _ = txt.shape
+        H = self.num_heads
+
+        img_mod1, img_mod2 = self.img_mod(vec)
+        txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
+        img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_q, img_k = self.img_attn.norm(img_q, img_k)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
+        txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
+
+        # run actual attention
+        q = mx.concatenate([txt_q, img_q], axis=2)
+        k = mx.concatenate([txt_k, img_k], axis=2)
+        v = mx.concatenate([txt_v, img_v], axis=2)
+
+        q = self.img_attn.rope(q)
+        k = self.img_attn.rope(k)
+        attn = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=q.shape[-1] ** (-0.5)
+        )
+        attn = attn.transpose(0, 2, 1, 3).reshape(B, L + S, -1)
+        txt_attn, img_attn = mx.split(attn, [S], axis=1)
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp(
+            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
+        )
+
+        # calculate the txt blocks
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp(
+            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
+        )
+
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: Optional[float] = None,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+
+        self.mlp_act = nn.GELU(approximate="tanh")
+        self.modulation = Modulation(hidden_size, double=False)
+
+        self.rope = nn.RoPE(head_dim, True, base=10000)
+
+    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
+        B, L, _ = x.shape
+        H = self.num_heads
+
+        mod, _ = self.modulation(vec)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+
+        q, k, v, mlp = mx.split(
+            self.linear1(x_mod),
+            [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
+            axis=-1,
+        )
+        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        q, k = self.norm(q, k)
+
+        # compute attention
+        q = self.rope(q)
+        k = self.rope(k)
+        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** (-0.5))
+        y = y.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        # compute activation in mlp stream, cat again and run second linear layer
+        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
+        return x + mod.gate * y
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.linear = nn.Linear(
+            hidden_size, patch_size * patch_size * out_channels, bias=True
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def __call__(self, x: mx.array, vec: mx.array):
+        shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
+        x = (1
+ scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/flux/flux/model.py b/flux/flux/model.py new file mode 100644 index 00000000..1bffad1c --- /dev/null +++ b/flux/flux/model.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, +) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + + def __call__( + self, + img: mx.array, + img_ids: mx.array, + txt: mx.array, + txt_ids: mx.array, + timesteps: mx.array, + y: mx.array, + guidance: Optional[mx.array] = None, + ) -> mx.array: + pass diff --git a/flux/flux/utils.py b/flux/flux/utils.py new file mode 100644 index 00000000..979d53d7 --- /dev/null +++ b/flux/flux/utils.py @@ -0,0 +1,138 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +from huggingface_hub import hf_hub_download + +from .autoencoder import AutoEncoder, AutoEncoderParams +from .model import Flux, FluxParams + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: Optional[str] + ae_path: Optional[str] + repo_id: Optional[str] + repo_flow: Optional[str] + repo_ae: Optional[str] + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def load_flow_model(name: str, hf_download: bool = True): + # Get the safetensors file to load + ckpt_path = configs[name].ckpt_path + + # Download if needed + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + 
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + # Make the model + model = Flux(configs[name].params) + + # Load the checkpoint if needed + if ckpt_path is not None: + weights = mx.load(ckpt_path) + weights = model.sanitize(weights) + model.load_weights(list(weights.items())) + + return model + + +def load_ae(name: str, hf_download: bool = True): + # Get the safetensors file to load + ckpt_path = configs[name].ae_path + + # Download if needed + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + + # Make the autoencoder + ae = AutoEncoder(configs[name].ae_params) + + # Load the checkpoint if needed + if ckpt_path is not None: + weights = mx.load(ckpt_path) + weights = ae.sanitize(weights) + ae.load_weights(list(weights.items())) + + return ae
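
Note: a quick smoke test for the autoencoder path added above. This is a minimal sketch, assuming the package from this diff is importable as `flux` and that `ae.safetensors` can be fetched from the Hugging Face Hub; the random input only illustrates the channels-last layout this port uses.

import mlx.core as mx

from flux.utils import load_ae

# Downloads and sanitizes ae.safetensors when the AE env var is unset.
ae = load_ae("flux-schnell")

# MLX convolutions are channels-last, so images are (B, H, W, C) here,
# unlike the NCHW layout of the reference PyTorch implementation.
img = mx.random.uniform(shape=(1, 256, 256, 3))

z = ae.encode(img)    # (1, 32, 32, 16): 8x spatial reduction, z_channels=16
recon = ae.decode(z)  # back to (1, 256, 256, 3)
mx.eval(recon)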
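
For reference, the weight conversion that `AutoEncoder.sanitize` performs on the upstream PyTorch checkpoint, sketched on dummy tensors rather than real checkpoint entries (the shapes and names below are illustrative only).

import mlx.core as mx

# A PyTorch Conv2d kernel is (out_channels, in_channels, kH, kW);
# MLX's Conv2d expects (out_channels, kH, kW, in_channels).
w_conv = mx.random.normal(shape=(128, 16, 3, 3))
w_conv = w_conv.transpose(0, 2, 3, 1)

# 1x1 convolutions (e.g. the q/k/v projections in AttnBlock) are stored as
# nn.Linear in this port, so their kernels collapse to (out_channels, in_channels).
w_1x1 = mx.random.normal(shape=(512, 512, 1, 1))
w_1x1 = w_1x1.transpose(0, 2, 3, 1).squeeze((1, 2))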