From ca883431188de5428c32f1c9a13fdeefbe8d6b42 Mon Sep 17 00:00:00 2001 From: madroid Date: Sun, 13 Oct 2024 01:17:37 +0800 Subject: [PATCH] FLUX: extract FluxPipeline from __init__.py to flux.py --- flux/flux/__init__.py | 238 +--------------------------------------- flux/flux/flux.py | 246 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 235 deletions(-) create mode 100644 flux/flux/flux.py diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 33bd815f..b1122d75 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -1,15 +1,10 @@ # Copyright © 2024 Apple Inc. -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten -from tqdm import tqdm - -from .datasets import load_dataset +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline from .lora import LoRALinear from .sampler import FluxSampler +from .trainer import Trainer from .utils import ( load_ae, load_clip, @@ -18,230 +13,3 @@ from .utils import ( load_t5, load_t5_tokenizer, ) - - -class FluxPipeline: - def __init__(self, name: str, t5_padding: bool = True): - self.dtype = mx.bfloat16 - self.name = name - self.t5_padding = t5_padding - - self.ae = load_ae(name) - self.flow = load_flow_model(name) - self.clip = load_clip(name) - self.clip_tokenizer = load_clip_tokenizer(name) - self.t5 = load_t5(name) - self.t5_tokenizer = load_t5_tokenizer(name) - self.sampler = FluxSampler(name) - - def ensure_models_are_loaded(self): - mx.eval( - self.ae.parameters(), - self.flow.parameters(), - self.clip.parameters(), - self.t5.parameters(), - ) - - def reload_text_encoders(self): - self.t5 = load_t5(self.name) - self.clip = load_clip(self.name) - - def tokenize(self, text): - t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) - clip_tokens = self.clip_tokenizer.encode(text) - return t5_tokens, clip_tokens - - def _prepare_latent_images(self, x): - b, h, w, c = x.shape - - # Pack the latent image to 2x2 patches - x = x.reshape(b, h // 2, 2, w // 2, 2, c) - x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) - - # Create positions ids used to positionally encode each patch. Due to - # the way RoPE works, this results in an interesting positional - # encoding where parts of the feature are holding different positional - # information. Namely, the first part holds information independent of - # the spatial position (hence 0s), the 2nd part holds vertical spatial - # information and the last one horizontal. - i = mx.zeros((h // 2, w // 2), dtype=mx.int32) - j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") - x_ids = mx.stack([i, j, k], axis=-1) - x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) - - return x, x_ids - - def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): - # Prepare the text features - txt = self.t5(t5_tokens) - if len(txt) == 1 and n_images > 1: - txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) - txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) - - # Prepare the clip text features - vec = self.clip(clip_tokens).pooled_output - if len(vec) == 1 and n_images > 1: - vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) - - return txt, txt_ids, vec - - def _denoising_loop( - self, - x_t, - x_ids, - txt, - txt_ids, - vec, - num_steps: int = 35, - guidance: float = 4.0, - start: float = 1, - stop: float = 0, - ): - B = len(x_t) - - def scalar(x): - return mx.full((B,), x, dtype=self.dtype) - - guidance = scalar(guidance) - timesteps = self.sampler.timesteps( - num_steps, - x_t.shape[1], - start=start, - stop=stop, - ) - for i in range(num_steps): - t = timesteps[i] - t_prev = timesteps[i + 1] - - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=scalar(t), - guidance=guidance, - ) - x_t = self.sampler.step(pred, x_t, t, t_prev) - - yield x_t - - def generate_latents( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - ): - # Set the PRNG state - if seed is not None: - mx.random.seed(seed) - - # Create the latent variables - x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) - x_T, x_ids = self._prepare_latent_images(x_T) - - # Get the conditioning - t5_tokens, clip_tokens = self.tokenize(text) - txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) - - # Yield the conditioning for controlled evaluation by the caller - yield (x_T, x_ids, txt, txt_ids, vec) - - # Yield the latent sequences from the denoising loop - yield from self._denoising_loop( - x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance - ) - - def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): - h, w = latent_size - x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) - x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) - x = self.ae.decode(x) - return mx.clip(x + 1, 0, 2) * 0.5 - - def generate_images( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - reload_text_encoders: bool = True, - progress: bool = True, - ): - latents = self.generate_latents( - text, n_images, num_steps, guidance, latent_size, seed - ) - mx.eval(next(latents)) - - if reload_text_encoders: - self.reload_text_encoders() - - for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): - mx.eval(x_t) - - images = [] - for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i: i + 1])) - mx.eval(images[-1]) - images = mx.concatenate(images, axis=0) - mx.eval(images) - - return images - - def training_loss( - self, - x_0: mx.array, - t5_features: mx.array, - clip_features: mx.array, - guidance: mx.array, - ): - # Get the text conditioning - txt = t5_features - txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) - vec = clip_features - - # Prepare the latent input - x_0, x_ids = self._prepare_latent_images(x_0) - - # Forward process - t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=self.dtype) - x_t = self.sampler.add_noise(x_0, t, noise=eps) - x_t = mx.stop_gradient(x_t) - - # Do the denoising - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t, - guidance=guidance, - ) - - return (pred + x_0 - eps).square().mean() - - def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): - """Swap the linear layers in the transformer blocks with LoRA layers.""" - all_blocks = self.flow.double_blocks + self.flow.single_blocks - all_blocks.reverse() - num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) - for i, block in zip(range(num_blocks), all_blocks): - loras = [] - for name, module in block.named_modules(): - if isinstance(module, nn.Linear): - loras.append((name, LoRALinear.from_base(module, r=rank))) - block.update_modules(tree_unflatten(loras)) - - def fuse_lora_layers(self): - fused_layers = [] - for name, module in self.flow.named_modules(): - if isinstance(module, LoRALinear): - fused_layers.append((name, module.fuse())) - self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/flux.py b/flux/flux/flux.py new file mode 100644 index 00000000..b29a3c55 --- /dev/null +++ b/flux/flux/flux.py @@ -0,0 +1,246 @@ +# Copyright © 2024 Apple Inc. + +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. + i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress): + images.append(self.decode(x_t[i: i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers))