Refactor the pipeline

This commit is contained in:
Angelos Katharopoulos 2024-10-02 14:34:45 -07:00
parent aefe60e79d
commit 9eef46e645
3 changed files with 145 additions and 75 deletions

View File

@ -5,6 +5,7 @@ from typing import Tuple
import mlx.core as mx
from tqdm import tqdm
from .sampler import FluxSampler
from .utils import (
load_ae,
load_clip,
@ -17,6 +18,7 @@ from .utils import (
class FluxPipeline:
def __init__(self, name: str):
self.dtype = mx.bfloat16
self.name = name
self.ae = load_ae(name)
self.flow = load_flow_model(name)
@ -24,7 +26,7 @@ class FluxPipeline:
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.dtype = mx.bfloat16
self.sampler = FluxSampler(shift="schnell" not in name)
def ensure_models_are_loaded(self):
mx.eval(
@ -34,65 +36,82 @@ class FluxPipeline:
self.t5.parameters(),
)
def _prior(self, n_images: int = 1, latent_size: Tuple[int, int] = (64, 64)):
return mx.random.normal(
shape=(n_images, *latent_size, 16),
dtype=self.dtype,
)
def _prepare(self, x, text):
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Prepare the latent image input and its ids for positional encoding
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
x_ids = mx.concatenate(
[
mx.zeros((h // 2, w // 2, 1), dtype=mx.int32),
mx.broadcast_to(mx.arange(h // 2)[:, None, None], (h // 2, w // 2, 1)),
mx.broadcast_to(mx.arange(w // 2)[None, :, None], (h // 2, w // 2, 1)),
],
axis=-1,
)
x_ids = mx.broadcast_to(x_ids.reshape(1, h * w // 4, 3), (b, h * w // 4, 3))
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, text):
# Prepare the text features
t5_tokens = mx.array([self.t5_tokenizer.tokenize(text)])
t5_tokens = self.t5_tokenizer.encode(text)
txt = self.t5(t5_tokens)
txt = mx.broadcast_to(txt, (b, *txt.shape[1:]))
txt_ids = mx.zeros((b, txt.shape[1], 3), dtype=mx.int32)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
clip_tokens = mx.array([self.clip_tokenizer.tokenize(text)])
clip_tokens = self.clip_tokenizer.encode(text)
vec = self.clip(clip_tokens).pooled_output
vec = mx.broadcast_to(vec, (b, *vec.shape[1:]))
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return {
"img": x,
"img_ids": x_ids,
"txt": txt,
"txt_ids": txt_ids,
"vec": vec,
}
return txt, txt_ids, vec
def _get_shedule(
def _denoising_loop(
self,
num_steps,
image_seq_len,
base_shift: float = 0.5,
max_shift: float = 1.5,
shift: bool = True,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
timesteps = mx.linspace(1, 0, num_steps + 1)
B = len(x_t)
if shift:
x = image_seq_len
x1, x2 = 256, 4096
y1, y2 = base_shift, max_shift
mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
return timesteps
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
@ -108,38 +127,21 @@ class FluxPipeline:
mx.random.seed(seed)
# Create the latent variables
x_T = self._prior(n_images, latent_size)
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the initial inputs
inputs = self._prepare(x_T, text)
# Get the conditioning
txt, txt_ids, vec = self._prepare_conditioning(n_images, text)
# Perform the denoising loop
mx.eval(inputs)
timesteps = self._get_shedule(
num_steps, x_T.shape[1], shift="schnell" not in self.name
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
timesteps = timesteps.tolist()
guidance = mx.full((n_images,), guidance, dtype=self.dtype)
for t, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
t_arr = mx.full((n_images,), t, dtype=self.dtype)
pred = self.flow(
img=inputs["img"],
img_ids=inputs["img_ids"],
txt=inputs["txt"],
txt_ids=inputs["txt_ids"],
y=inputs["vec"],
timesteps=t_arr,
guidance=guidance,
)
inputs["img"] = inputs["img"] + (t_prev - t) * pred
mx.eval(inputs["img"])
img = inputs["img"]
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
img = img.reshape(n_images, h // 2, w // 2, -1, 2, 2)
img = img.transpose(0, 1, 4, 2, 5, 3).reshape(n_images, h, w, -1)
img = self.ae.decode(img)
mx.eval(img)
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
x = (mx.clip(x + 1, 0, 2) * 127.5).astype(mx.uint8)
return ((mx.clip(img, -1, 1) + 1) * 127.5).astype(mx.uint8)
return x

41
flux/flux/sampler.py Normal file
View File

@ -0,0 +1,41 @@
from functools import lru_cache
import mlx.core as mx
class FluxSampler:
def __init__(
self, base_shift: float = 0.5, max_shift: float = 1.5, shift: bool = True
):
self._base_shift = base_shift
self._max_shift = max_shift
self._shift = shift
@lru_cache
def timesteps(
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
):
t = mx.linspace(start, stop, num_steps + 1)
if self._shift:
x = image_sequence_length
x1, x2 = 256, 4096
y1, y2 = self._base_shift, self._max_shift
mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
t = mx.exp(mu) / (mx.exp(mu) + (1 / t - 1))
return t.tolist()
def sample_prior(self, shape, dtype=mx.float32, key=None):
return mx.random.normal(shape, dtype=dtype, key=key)
def add_noise(self, x, t, noise=None, key=None):
noise = (
noise
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):
return x_t + (t_prev - t) * pred

View File

@ -1,3 +1,4 @@
import mlx.core as mx
import regex
from sentencepiece import SentencePieceProcessor
@ -104,6 +105,17 @@ class CLIPTokenizer:
return tokens
def encode(self, text):
if not isinstance(text, list):
return self.encode([text])
tokens = self.tokenize(text)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([self.eos_token] * (length - len(t)))
return mx.array(tokens)
class T5Tokenizer:
def __init__(self, model_file):
@ -132,11 +144,26 @@ class T5Tokenizer:
return self._tokenizer.eos_id()
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
tokens = self._tokenizer.encode(text)
if prepend_bos and self.bos_token > 0:
if prepend_bos and self.bos_token >= 0:
tokens = [self.bos_token] + tokens
if append_eos and self.eos_token > 0:
if append_eos and self.eos_token >= 0:
tokens.append(self.eos_token)
return tokens
def encode(self, text):
if not isinstance(text, list):
return self.encode([text])
eos_token = self.eos_token if self.eos_token >= 0 else 0
tokens = self.tokenize(text)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([eos_token] * (length - len(t)))
return mx.array(tokens)