mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-17 00:28:07 +08:00
Add FLUX finetuning (#1028)
This commit is contained in:

committed by
GitHub

parent
d72fdeb4ee
commit
a5f2bab070
248
flux/flux/__init__.py
Normal file
248
flux/flux/__init__.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_clip_tokenizer,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str, t5_padding: bool = True):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.t5_padding = t5_padding
|
||||
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
self.clip = load_clip(name)
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.sampler = FluxSampler(name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
self.ae.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.clip.parameters(),
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def reload_text_encoders(self):
|
||||
self.t5 = load_t5(self.name)
|
||||
self.clip = load_clip(self.name)
|
||||
|
||||
def tokenize(self, text):
|
||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
return t5_tokens, clip_tokens
|
||||
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
||||
# Prepare the text features
|
||||
txt = self.t5(t5_tokens)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _denoising_loop(
|
||||
self,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
B = len(x_t)
|
||||
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the conditioning
|
||||
t5_tokens, clip_tokens = self.tokenize(text)
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
||||
|
||||
# Yield the conditioning for controlled evaluation by the caller
|
||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
||||
|
||||
# Yield the latent sequences from the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
return mx.clip(x + 1, 0, 2) * 0.5
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
reload_text_encoders: bool = True,
|
||||
progress: bool = True,
|
||||
):
|
||||
latents = self.generate_latents(
|
||||
text, n_images, num_steps, guidance, latent_size, seed
|
||||
)
|
||||
mx.eval(next(latents))
|
||||
|
||||
if reload_text_encoders:
|
||||
self.reload_text_encoders()
|
||||
|
||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
||||
mx.eval(x_t)
|
||||
|
||||
images = []
|
||||
for i in tqdm(range(len(x_t)), disable=not progress):
|
||||
images.append(self.decode(x_t[i : i + 1]))
|
||||
mx.eval(images[-1])
|
||||
images = mx.concatenate(images, axis=0)
|
||||
mx.eval(images)
|
||||
|
||||
return images
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return (pred + x_0 - eps).square().mean()
|
||||
|
||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
||||
all_blocks.reverse()
|
||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
||||
for i, block in zip(range(num_blocks), all_blocks):
|
||||
loras = []
|
||||
for name, module in block.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
||||
block.update_modules(tree_unflatten(loras))
|
||||
|
||||
def fuse_lora_layers(self):
|
||||
fused_layers = []
|
||||
for name, module in self.flow.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
fused_layers.append((name, module.fuse()))
|
||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
357
flux/flux/autoencoder.py
Normal file
357
flux/flux/autoencoder.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.nn.layers.upsample import upsample_nearest
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: List[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(
|
||||
num_groups=32,
|
||||
dims=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.q = nn.Linear(in_channels, in_channels)
|
||||
self.k = nn.Linear(in_channels, in_channels)
|
||||
self.v = nn.Linear(in_channels, in_channels)
|
||||
self.proj_out = nn.Linear(in_channels, in_channels)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
B, H, W, C = x.shape
|
||||
|
||||
y = x.reshape(B, 1, -1, C)
|
||||
y = self.norm(y)
|
||||
q = self.q(y)
|
||||
k = self.k(y)
|
||||
v = self.v(y)
|
||||
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
|
||||
y = self.proj_out(y)
|
||||
|
||||
return x + y.reshape(B, H, W, C)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(
|
||||
num_groups=32,
|
||||
dims=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=32,
|
||||
dims=out_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def __call__(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
x = upsample_nearest(x, (2, 2))
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = []
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = []
|
||||
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = {}
|
||||
down["block"] = block
|
||||
down["attn"] = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down["downsample"] = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = {}
|
||||
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid["attn_1"] = AttnBlock(block_in)
|
||||
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(
|
||||
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level]["block"][i_block](hs[-1])
|
||||
|
||||
# TODO: Remove the attn
|
||||
if len(self.down[i_level]["attn"]) > 0:
|
||||
h = self.down[i_level]["attn"][i_block](h)
|
||||
|
||||
hs.append(h)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level]["downsample"](hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid["block_1"](h)
|
||||
h = self.mid["attn_1"](h)
|
||||
h = self.mid["block_2"](h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = {}
|
||||
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid["attn_1"] = AttnBlock(block_in)
|
||||
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# upsampling
|
||||
self.up = []
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = []
|
||||
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
||||
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
up = {}
|
||||
up["block"] = block
|
||||
up["attn"] = attn
|
||||
if i_level != 0:
|
||||
up["upsample"] = Upsample(block_in)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(
|
||||
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def __call__(self, z: mx.array):
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid["block_1"](h)
|
||||
h = self.mid["attn_1"](h)
|
||||
h = self.mid["block_2"](h)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level]["block"][i_block](h)
|
||||
|
||||
# TODO: Remove the attn
|
||||
if len(self.up[i_level]["attn"]) > 0:
|
||||
h = self.up[i_level]["attn"][i_block](h)
|
||||
|
||||
if i_level != 0:
|
||||
h = self.up[i_level]["upsample"](h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __call__(self, z: mx.array):
|
||||
mean, logvar = mx.split(z, 2, axis=-1)
|
||||
if self.training:
|
||||
std = mx.exp(0.5 * logvar)
|
||||
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
|
||||
return mean + std * eps
|
||||
else:
|
||||
return mean
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, params: AutoEncoderParams):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
out_ch=params.out_ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.reg = DiagonalGaussian()
|
||||
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for k, w in weights.items():
|
||||
if w.ndim == 4:
|
||||
w = w.transpose(0, 2, 3, 1)
|
||||
w = w.reshape(-1).reshape(w.shape)
|
||||
if w.shape[1:3] == (1, 1):
|
||||
w = w.squeeze((1, 2))
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def encode(self, x: mx.array):
|
||||
z = self.reg(self.encoder(x))
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: mx.array):
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
return self.decoder(z)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return self.decode(self.encode(x))
|
154
flux/flux/clip.py
Normal file
154
flux/flux/clip.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextModelConfig:
|
||||
num_layers: int = 23
|
||||
model_dims: int = 1024
|
||||
num_heads: int = 16
|
||||
max_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
hidden_act: str = "quick_gelu"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config):
|
||||
return cls(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
model_dims=config["hidden_size"],
|
||||
num_heads=config["num_attention_heads"],
|
||||
max_length=config["max_position_embeddings"],
|
||||
vocab_size=config["vocab_size"],
|
||||
hidden_act=config["hidden_act"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPOutput:
|
||||
# The last_hidden_state indexed at the EOS token and possibly projected if
|
||||
# the model has a projection layer
|
||||
pooled_output: Optional[mx.array] = None
|
||||
|
||||
# The full sequence output of the transformer after the final layernorm
|
||||
last_hidden_state: Optional[mx.array] = None
|
||||
|
||||
# A list of hidden states corresponding to the outputs of the transformer layers
|
||||
hidden_states: Optional[List[mx.array]] = None
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, model_dims: int, num_heads: int, activation: str):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||
|
||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
|
||||
|
||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||
|
||||
self.act = _ACTIVATIONS[activation]
|
||||
|
||||
def __call__(self, x, attn_mask=None):
|
||||
y = self.layer_norm1(x)
|
||||
y = self.attention(y, y, y, attn_mask)
|
||||
x = y + x
|
||||
|
||||
y = self.layer_norm2(x)
|
||||
y = self.linear1(y)
|
||||
y = self.act(y)
|
||||
y = self.linear2(y)
|
||||
x = y + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextModelConfig):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||
self.layers = [
|
||||
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||
|
||||
def _get_mask(self, N, dtype):
|
||||
indices = mx.arange(N)
|
||||
mask = indices[:, None] < indices[None]
|
||||
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
|
||||
return mask
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for key, w in weights.items():
|
||||
# Remove prefixes
|
||||
if key.startswith("text_model."):
|
||||
key = key[11:]
|
||||
if key.startswith("embeddings."):
|
||||
key = key[11:]
|
||||
if key.startswith("encoder."):
|
||||
key = key[8:]
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
|
||||
new_weights[key] = w
|
||||
|
||||
return new_weights
|
||||
|
||||
def __call__(self, x):
|
||||
# Extract some shapes
|
||||
B, N = x.shape
|
||||
eos_tokens = x.argmax(-1)
|
||||
|
||||
# Compute the embeddings
|
||||
x = self.token_embedding(x)
|
||||
x = x + self.position_embedding.weight[:N]
|
||||
|
||||
# Compute the features from the transformer
|
||||
mask = self._get_mask(N, x.dtype)
|
||||
hidden_states = []
|
||||
for l in self.layers:
|
||||
x = l(x, mask)
|
||||
hidden_states.append(x)
|
||||
|
||||
# Apply the final layernorm and return
|
||||
x = self.final_layer_norm(x)
|
||||
last_hidden_state = x
|
||||
|
||||
# Select the EOS token
|
||||
pooled_output = x[mx.arange(len(x)), eos_tokens]
|
||||
|
||||
return CLIPOutput(
|
||||
pooled_output=pooled_output,
|
||||
last_hidden_state=last_hidden_state,
|
||||
hidden_states=hidden_states,
|
||||
)
|
302
flux/flux/layers.py
Normal file
302
flux/flux/layers.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def _rope(pos: mx.array, dim: int, theta: float):
|
||||
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
x = pos[..., None] * omega
|
||||
cosx = mx.cos(x)
|
||||
sinx = mx.sin(x)
|
||||
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
|
||||
pe = pe.reshape(*pe.shape[:-1], 2, 2)
|
||||
|
||||
return pe
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def _ab_plus_cd(a, b, c, d):
|
||||
return a * b + c * d
|
||||
|
||||
|
||||
def _apply_rope(x, pe):
|
||||
s = x.shape
|
||||
x = x.reshape(*s[:-1], -1, 1, 2)
|
||||
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
|
||||
return x.reshape(s)
|
||||
|
||||
|
||||
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
|
||||
B, H, L, D = q.shape
|
||||
|
||||
q = _apply_rope(q, pe)
|
||||
k = _apply_rope(k, pe)
|
||||
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
|
||||
|
||||
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
|
||||
def timestep_embedding(
|
||||
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
|
||||
):
|
||||
half = dim // 2
|
||||
freqs = mx.arange(0, half, dtype=mx.float32) / half
|
||||
freqs = freqs * (-math.log(max_period))
|
||||
freqs = mx.exp(freqs)
|
||||
|
||||
x = (time_factor * t)[:, None] * freqs[None]
|
||||
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
|
||||
|
||||
return x.astype(t.dtype)
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def __call__(self, ids: mx.array):
|
||||
n_axes = ids.shape[-1]
|
||||
pe = mx.concatenate(
|
||||
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
axis=-3,
|
||||
)
|
||||
|
||||
return pe[:, None]
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.out_layer(nn.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class QKNorm(nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = nn.RMSNorm(dim)
|
||||
self.key_norm = nn.RMSNorm(dim)
|
||||
|
||||
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
|
||||
return self.query_norm(q), self.key_norm(k)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
|
||||
H = self.num_heads
|
||||
B, L, _ = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
q, k = self.norm(q, k)
|
||||
x = _attention(q, k, v, pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: mx.array
|
||||
scale: mx.array
|
||||
gate: mx.array
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
||||
x = self.lin(nn.silu(x))
|
||||
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
|
||||
|
||||
mod1 = ModulationOut(*xs[:3])
|
||||
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
||||
|
||||
return mod1, mod2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(
|
||||
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
||||
)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approx="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(
|
||||
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
||||
)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approx="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
B, L, _ = img.shape
|
||||
_, S, _ = txt.shape
|
||||
H = self.num_heads
|
||||
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
|
||||
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
|
||||
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
|
||||
|
||||
# run actual attention
|
||||
q = mx.concatenate([txt_q, img_q], axis=2)
|
||||
k = mx.concatenate([txt_k, img_k], axis=2)
|
||||
v = mx.concatenate([txt_v, img_v], axis=2)
|
||||
|
||||
attn = _attention(q, k, v, pe)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp(
|
||||
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
||||
)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp(
|
||||
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approx="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
|
||||
B, L, _ = x.shape
|
||||
H = self.num_heads
|
||||
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
|
||||
q, k, v, mlp = mx.split(
|
||||
self.linear1(x_mod),
|
||||
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
|
||||
axis=-1,
|
||||
)
|
||||
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
q, k = self.norm(q, k)
|
||||
|
||||
# compute attention
|
||||
y = _attention(q, k, v, pe)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
|
||||
return x + mod.gate * y
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, bias=True
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, vec: mx.array):
|
||||
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
76
flux/flux/lora.py
Normal file
76
flux/flux/lora.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_base(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
lora_lin = LoRALinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def fuse(self):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
dtype = weight.dtype
|
||||
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
lora_b = self.scale * self.lora_b.T
|
||||
lora_a = self.lora_a.T
|
||||
fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 1.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
y = self.linear(x)
|
||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||
return y + (self.scale * z).astype(x.dtype)
|
134
flux/flux/model.py
Normal file
134
flux/flux/model.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(
|
||||
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
||||
)
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(
|
||||
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
||||
)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
if params.guidance_embed
|
||||
else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = [
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
|
||||
self.single_blocks = [
|
||||
SingleStreamBlock(
|
||||
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for k, w in weights.items():
|
||||
if k.endswith(".scale"):
|
||||
k = k[:-6] + ".weight"
|
||||
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
|
||||
if f".{seq}." in k:
|
||||
k = k.replace(f".{seq}.", f".{seq}.layers.")
|
||||
break
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
img: mx.array,
|
||||
img_ids: mx.array,
|
||||
txt: mx.array,
|
||||
txt_ids: mx.array,
|
||||
timesteps: mx.array,
|
||||
y: mx.array,
|
||||
guidance: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError(
|
||||
"Didn't get guidance strength for guidance distilled model."
|
||||
)
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = mx.concatenate([txt_ids, img_ids], axis=1)
|
||||
pe = self.pe_embedder(ids).astype(img.dtype)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = mx.concatenate([txt, img], axis=1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
return img
|
56
flux/flux/sampler.py
Normal file
56
flux/flux/sampler.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class FluxSampler:
|
||||
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
|
||||
self._base_shift = base_shift
|
||||
self._max_shift = max_shift
|
||||
self._schnell = "schnell" in name
|
||||
|
||||
def _time_shift(self, x, t):
|
||||
x1, x2 = 256, 4096
|
||||
t1, t2 = self._base_shift, self._max_shift
|
||||
exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
|
||||
t = exp_mu / (exp_mu + (1 / t - 1))
|
||||
return t
|
||||
|
||||
@lru_cache
|
||||
def timesteps(
|
||||
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
|
||||
):
|
||||
t = mx.linspace(start, stop, num_steps + 1)
|
||||
|
||||
if self._schnell:
|
||||
t = self._time_shift(image_sequence_length, t)
|
||||
|
||||
return t.tolist()
|
||||
|
||||
def random_timesteps(self, B, L, dtype=mx.float32, key=None):
|
||||
if self._schnell:
|
||||
# TODO: Should we upweigh 1 and 0.75?
|
||||
t = mx.random.randint(1, 5, shape=(B,), key=key)
|
||||
t = t.astype(dtype) / 4
|
||||
else:
|
||||
t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
|
||||
t = self._time_shift(L, t)
|
||||
|
||||
return t
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
return mx.random.normal(shape, dtype=dtype, key=key)
|
||||
|
||||
def add_noise(self, x, t, noise=None, key=None):
|
||||
noise = (
|
||||
noise
|
||||
if noise is not None
|
||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
||||
)
|
||||
return x * (1 - t) + t * noise
|
||||
|
||||
def step(self, pred, x_t, t, t_prev):
|
||||
return x_t + (t_prev - t) * pred
|
244
flux/flux/t5.py
Normal file
244
flux/flux/t5.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
_SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
_ENCODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class T5Config:
|
||||
vocab_size: int
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
relative_attention_num_buckets: int
|
||||
d_kv: int
|
||||
d_model: int
|
||||
feed_forward_proj: str
|
||||
tie_word_embeddings: bool
|
||||
|
||||
d_ff: Optional[int] = None
|
||||
num_decoder_layers: Optional[int] = None
|
||||
relative_attention_max_distance: int = 128
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config):
|
||||
return cls(
|
||||
vocab_size=config["vocab_size"],
|
||||
num_layers=config["num_layers"],
|
||||
num_heads=config["num_heads"],
|
||||
relative_attention_num_buckets=config["relative_attention_num_buckets"],
|
||||
d_kv=config["d_kv"],
|
||||
d_model=config["d_model"],
|
||||
feed_forward_proj=config["feed_forward_proj"],
|
||||
tie_word_embeddings=config["tie_word_embeddings"],
|
||||
d_ff=config.get("d_ff", 4 * config["d_model"]),
|
||||
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
|
||||
relative_attention_max_distance=config.get(
|
||||
"relative_attention_max_distance", 128
|
||||
),
|
||||
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
|
||||
)
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, config: T5Config, bidirectional: bool):
|
||||
self.bidirectional = bidirectional
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
|
||||
num_buckets = num_buckets // 2 if bidirectional else num_buckets
|
||||
max_exact = num_buckets // 2
|
||||
|
||||
abspos = rpos.abs()
|
||||
is_small = abspos < max_exact
|
||||
|
||||
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
|
||||
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
|
||||
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
|
||||
|
||||
buckets = mx.where(is_small, abspos, buckets_large)
|
||||
if bidirectional:
|
||||
buckets = buckets + (rpos > 0) * num_buckets
|
||||
else:
|
||||
buckets = buckets * (rpos < 0)
|
||||
|
||||
return buckets
|
||||
|
||||
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = mx.arange(offset, query_length)[:, None]
|
||||
memory_position = mx.arange(key_length)[None, :]
|
||||
|
||||
# shape (query_length, key_length)
|
||||
relative_position = memory_position - context_position
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position,
|
||||
bidirectional=self.bidirectional,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
|
||||
# shape (query_length, key_length, num_heads)
|
||||
values = self.embeddings(relative_position_bucket)
|
||||
|
||||
# shape (num_heads, query_length, key_length)
|
||||
return values.transpose(2, 0, 1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
inner_dim = config.d_kv * config.num_heads
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
queries: mx.array,
|
||||
keys: mx.array,
|
||||
values: mx.array,
|
||||
mask: Optional[mx.array],
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
keys = mx.concatenate([key_cache, keys], axis=3)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
|
||||
values_hat = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
|
||||
)
|
||||
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.gated = config.feed_forward_proj.startswith("gated")
|
||||
if self.gated:
|
||||
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
else:
|
||||
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||
activation = config.feed_forward_proj.removeprefix("gated-")
|
||||
if activation == "relu":
|
||||
self.act = nn.relu
|
||||
elif activation == "gelu":
|
||||
self.act = nn.gelu
|
||||
elif activation == "silu":
|
||||
self.act = nn.silu
|
||||
else:
|
||||
raise ValueError(f"Unknown activation: {activation}")
|
||||
|
||||
def __call__(self, x):
|
||||
if self.gated:
|
||||
hidden_act = self.act(self.wi_0(x))
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_act * hidden_linear
|
||||
else:
|
||||
x = self.act(self.wi(x))
|
||||
return self.wo(x)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(self, x, mask):
|
||||
y = self.ln1(x)
|
||||
y, _ = self.attention(y, y, y, mask=mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y = self.dense(y)
|
||||
return x + y
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||
pos_bias = pos_bias.astype(x.dtype)
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=pos_bias)
|
||||
return self.ln(x)
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for k, w in weights.items():
|
||||
for old, new in _SHARED_REPLACEMENT_PATTERNS:
|
||||
k = k.replace(old, new)
|
||||
if k.startswith("encoder."):
|
||||
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
|
||||
k = k.replace(old, new)
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def __call__(self, inputs: mx.array):
|
||||
return self.encoder(self.wte(inputs))
|
185
flux/flux/tokenizers.py
Normal file
185
flux/flux/tokenizers.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import regex
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
class CLIPTokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab, max_length=77):
|
||||
self.max_length = max_length
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return "<|startoftext|>"
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self.vocab[self.bos]
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return "<|endoftext|>"
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self.vocab[self.eos]
|
||||
|
||||
def bpe(self, text):
|
||||
if text in self._cache:
|
||||
return self._cache[text]
|
||||
|
||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
if not unique_bigrams:
|
||||
return unigrams
|
||||
|
||||
# In every iteration try to merge the two most likely bigrams. If none
|
||||
# was merged we are done.
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(
|
||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||
)
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
new_unigrams = []
|
||||
skip = False
|
||||
for a, b in zip(unigrams, unigrams[1:]):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
if (a, b) == bigram:
|
||||
new_unigrams.append(a + b)
|
||||
skip = True
|
||||
|
||||
else:
|
||||
new_unigrams.append(a)
|
||||
|
||||
if not skip:
|
||||
new_unigrams.append(b)
|
||||
|
||||
unigrams = new_unigrams
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
self._cache[text] = unigrams
|
||||
|
||||
return unigrams
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||
# a much more thorough job here but this should suffice for 95% of
|
||||
# cases.
|
||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||
tokens = regex.findall(self.pat, clean_text)
|
||||
|
||||
# Split the tokens according to the byte-pair merge file
|
||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||
|
||||
# Map to token ids and return
|
||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||
if prepend_bos:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
if len(tokens) > self.max_length:
|
||||
tokens = tokens[: self.max_length]
|
||||
if append_eos:
|
||||
tokens[-1] = self.eos_token
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
if not isinstance(text, list):
|
||||
return self.encode([text])
|
||||
|
||||
tokens = self.tokenize(text)
|
||||
length = max(len(t) for t in tokens)
|
||||
for t in tokens:
|
||||
t.extend([self.eos_token] * (length - len(t)))
|
||||
|
||||
return mx.array(tokens)
|
||||
|
||||
|
||||
class T5Tokenizer:
|
||||
def __init__(self, model_file, max_length=512):
|
||||
self._tokenizer = SentencePieceProcessor(model_file)
|
||||
self.max_length = max_length
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.pad_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return self._tokenizer.pad_id()
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.bos_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self._tokenizer.bos_id()
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.eos_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self._tokenizer.eos_id()
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
|
||||
|
||||
tokens = self._tokenizer.encode(text)
|
||||
|
||||
if prepend_bos and self.bos_token >= 0:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos and self.eos_token >= 0:
|
||||
tokens.append(self.eos_token)
|
||||
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
|
||||
tokens += [self.pad_token] * (self.max_length - len(tokens))
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text, pad=True):
|
||||
if not isinstance(text, list):
|
||||
return self.encode([text], pad=pad)
|
||||
|
||||
pad_token = self.pad_token if self.pad_token >= 0 else 0
|
||||
tokens = self.tokenize(text, pad=pad)
|
||||
length = max(len(t) for t in tokens)
|
||||
for t in tokens:
|
||||
t.extend([pad_token] * (length - len(t)))
|
||||
|
||||
return mx.array(tokens)
|
209
flux/flux/utils.py
Normal file
209
flux/flux/utils.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from .clip import CLIPTextModel, CLIPTextModelConfig
|
||||
from .model import Flux, FluxParams
|
||||
from .t5 import T5Config, T5Encoder
|
||||
from .tokenizers import CLIPTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
ckpt_path: Optional[str]
|
||||
ae_path: Optional[str]
|
||||
repo_id: Optional[str]
|
||||
repo_flow: Optional[str]
|
||||
repo_ae: Optional[str]
|
||||
|
||||
|
||||
configs = {
|
||||
"flux-dev": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-dev",
|
||||
repo_flow="flux1-dev.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_DEV"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-schnell": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-schnell",
|
||||
repo_flow="flux1-schnell.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def load_flow_model(name: str, hf_download: bool = True):
|
||||
# Get the safetensors file to load
|
||||
ckpt_path = configs[name].ckpt_path
|
||||
|
||||
# Download if needed
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_flow is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
||||
|
||||
# Make the model
|
||||
model = Flux(configs[name].params)
|
||||
|
||||
# Load the checkpoint if needed
|
||||
if ckpt_path is not None:
|
||||
weights = mx.load(ckpt_path)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_ae(name: str, hf_download: bool = True):
|
||||
# Get the safetensors file to load
|
||||
ckpt_path = configs[name].ae_path
|
||||
|
||||
# Download if needed
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_ae is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
|
||||
|
||||
# Make the autoencoder
|
||||
ae = AutoEncoder(configs[name].ae_params)
|
||||
|
||||
# Load the checkpoint if needed
|
||||
if ckpt_path is not None:
|
||||
weights = mx.load(ckpt_path)
|
||||
weights = ae.sanitize(weights)
|
||||
ae.load_weights(list(weights.items()))
|
||||
|
||||
return ae
|
||||
|
||||
|
||||
def load_clip(name: str):
|
||||
# Load the config
|
||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
|
||||
with open(config_path) as f:
|
||||
config = CLIPTextModelConfig.from_dict(json.load(f))
|
||||
|
||||
# Make the clip text encoder
|
||||
clip = CLIPTextModel(config)
|
||||
|
||||
# Load the weights
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
|
||||
weights = mx.load(ckpt_path)
|
||||
weights = clip.sanitize(weights)
|
||||
clip.load_weights(list(weights.items()))
|
||||
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5(name: str):
|
||||
# Load the config
|
||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
|
||||
with open(config_path) as f:
|
||||
config = T5Config.from_dict(json.load(f))
|
||||
|
||||
# Make the T5 model
|
||||
t5 = T5Encoder(config)
|
||||
|
||||
# Load the weights
|
||||
model_index = hf_hub_download(
|
||||
configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
|
||||
)
|
||||
weight_files = set()
|
||||
with open(model_index) as f:
|
||||
for _, w in json.load(f)["weight_map"].items():
|
||||
weight_files.add(w)
|
||||
weights = {}
|
||||
for w in weight_files:
|
||||
w = f"text_encoder_2/{w}"
|
||||
w = hf_hub_download(configs[name].repo_id, w)
|
||||
weights.update(mx.load(w))
|
||||
weights = t5.sanitize(weights)
|
||||
t5.load_weights(list(weights.items()))
|
||||
|
||||
return t5
|
||||
|
||||
|
||||
def load_clip_tokenizer(name: str):
|
||||
vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
|
||||
with open(merges_file, encoding="utf-8") as f:
|
||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
|
||||
|
||||
|
||||
def load_t5_tokenizer(name: str, pad: bool = True):
|
||||
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
||||
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
Reference in New Issue
Block a user