mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Refactor the pipeline
This commit is contained in:
parent
aefe60e79d
commit
9eef46e645
@ -5,6 +5,7 @@ from typing import Tuple
|
||||
import mlx.core as mx
|
||||
from tqdm import tqdm
|
||||
|
||||
from .sampler import FluxSampler
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
@ -17,6 +18,7 @@ from .utils import (
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
@ -24,7 +26,7 @@ class FluxPipeline:
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.dtype = mx.bfloat16
|
||||
self.sampler = FluxSampler(shift="schnell" not in name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
@ -34,65 +36,82 @@ class FluxPipeline:
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def _prior(self, n_images: int = 1, latent_size: Tuple[int, int] = (64, 64)):
|
||||
return mx.random.normal(
|
||||
shape=(n_images, *latent_size, 16),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def _prepare(self, x, text):
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Prepare the latent image input and its ids for positional encoding
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
x_ids = mx.concatenate(
|
||||
[
|
||||
mx.zeros((h // 2, w // 2, 1), dtype=mx.int32),
|
||||
mx.broadcast_to(mx.arange(h // 2)[:, None, None], (h // 2, w // 2, 1)),
|
||||
mx.broadcast_to(mx.arange(w // 2)[None, :, None], (h // 2, w // 2, 1)),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
x_ids = mx.broadcast_to(x_ids.reshape(1, h * w // 4, 3), (b, h * w // 4, 3))
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, text):
|
||||
# Prepare the text features
|
||||
t5_tokens = mx.array([self.t5_tokenizer.tokenize(text)])
|
||||
t5_tokens = self.t5_tokenizer.encode(text)
|
||||
txt = self.t5(t5_tokens)
|
||||
txt = mx.broadcast_to(txt, (b, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((b, txt.shape[1], 3), dtype=mx.int32)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
clip_tokens = mx.array([self.clip_tokenizer.tokenize(text)])
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
vec = mx.broadcast_to(vec, (b, *vec.shape[1:]))
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return {
|
||||
"img": x,
|
||||
"img_ids": x_ids,
|
||||
"txt": txt,
|
||||
"txt_ids": txt_ids,
|
||||
"vec": vec,
|
||||
}
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _get_shedule(
|
||||
def _denoising_loop(
|
||||
self,
|
||||
num_steps,
|
||||
image_seq_len,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.5,
|
||||
shift: bool = True,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
timesteps = mx.linspace(1, 0, num_steps + 1)
|
||||
B = len(x_t)
|
||||
|
||||
if shift:
|
||||
x = image_seq_len
|
||||
x1, x2 = 256, 4096
|
||||
y1, y2 = base_shift, max_shift
|
||||
mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
|
||||
timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
return timesteps
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
@ -108,38 +127,21 @@ class FluxPipeline:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self._prior(n_images, latent_size)
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the initial inputs
|
||||
inputs = self._prepare(x_T, text)
|
||||
# Get the conditioning
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, text)
|
||||
|
||||
# Perform the denoising loop
|
||||
mx.eval(inputs)
|
||||
timesteps = self._get_shedule(
|
||||
num_steps, x_T.shape[1], shift="schnell" not in self.name
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
timesteps = timesteps.tolist()
|
||||
guidance = mx.full((n_images,), guidance, dtype=self.dtype)
|
||||
for t, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
|
||||
t_arr = mx.full((n_images,), t, dtype=self.dtype)
|
||||
pred = self.flow(
|
||||
img=inputs["img"],
|
||||
img_ids=inputs["img_ids"],
|
||||
txt=inputs["txt"],
|
||||
txt_ids=inputs["txt_ids"],
|
||||
y=inputs["vec"],
|
||||
timesteps=t_arr,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
inputs["img"] = inputs["img"] + (t_prev - t) * pred
|
||||
mx.eval(inputs["img"])
|
||||
|
||||
img = inputs["img"]
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
img = img.reshape(n_images, h // 2, w // 2, -1, 2, 2)
|
||||
img = img.transpose(0, 1, 4, 2, 5, 3).reshape(n_images, h, w, -1)
|
||||
img = self.ae.decode(img)
|
||||
mx.eval(img)
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
x = (mx.clip(x + 1, 0, 2) * 127.5).astype(mx.uint8)
|
||||
|
||||
return ((mx.clip(img, -1, 1) + 1) * 127.5).astype(mx.uint8)
|
||||
return x
|
||||
|
41
flux/flux/sampler.py
Normal file
41
flux/flux/sampler.py
Normal file
@ -0,0 +1,41 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class FluxSampler:
|
||||
def __init__(
|
||||
self, base_shift: float = 0.5, max_shift: float = 1.5, shift: bool = True
|
||||
):
|
||||
self._base_shift = base_shift
|
||||
self._max_shift = max_shift
|
||||
self._shift = shift
|
||||
|
||||
@lru_cache
|
||||
def timesteps(
|
||||
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
|
||||
):
|
||||
t = mx.linspace(start, stop, num_steps + 1)
|
||||
|
||||
if self._shift:
|
||||
x = image_sequence_length
|
||||
x1, x2 = 256, 4096
|
||||
y1, y2 = self._base_shift, self._max_shift
|
||||
mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
|
||||
t = mx.exp(mu) / (mx.exp(mu) + (1 / t - 1))
|
||||
|
||||
return t.tolist()
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
return mx.random.normal(shape, dtype=dtype, key=key)
|
||||
|
||||
def add_noise(self, x, t, noise=None, key=None):
|
||||
noise = (
|
||||
noise
|
||||
if noise is not None
|
||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
||||
)
|
||||
return x * (1 - t) + t * noise
|
||||
|
||||
def step(self, pred, x_t, t, t_prev):
|
||||
return x_t + (t_prev - t) * pred
|
@ -1,3 +1,4 @@
|
||||
import mlx.core as mx
|
||||
import regex
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
@ -104,6 +105,17 @@ class CLIPTokenizer:
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
if not isinstance(text, list):
|
||||
return self.encode([text])
|
||||
|
||||
tokens = self.tokenize(text)
|
||||
length = max(len(t) for t in tokens)
|
||||
for t in tokens:
|
||||
t.extend([self.eos_token] * (length - len(t)))
|
||||
|
||||
return mx.array(tokens)
|
||||
|
||||
|
||||
class T5Tokenizer:
|
||||
def __init__(self, model_file):
|
||||
@ -132,11 +144,26 @@ class T5Tokenizer:
|
||||
return self._tokenizer.eos_id()
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
tokens = self._tokenizer.encode(text)
|
||||
|
||||
if prepend_bos and self.bos_token > 0:
|
||||
if prepend_bos and self.bos_token >= 0:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos and self.eos_token > 0:
|
||||
if append_eos and self.eos_token >= 0:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
if not isinstance(text, list):
|
||||
return self.encode([text])
|
||||
|
||||
eos_token = self.eos_token if self.eos_token >= 0 else 0
|
||||
tokens = self.tokenize(text)
|
||||
length = max(len(t) for t in tokens)
|
||||
for t in tokens:
|
||||
t.extend([eos_token] * (length - len(t)))
|
||||
|
||||
return mx.array(tokens)
|
||||
|
Loading…
Reference in New Issue
Block a user