From ecd8828e337d5798ba3e9f802fa654fc486adbed Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 10 Oct 2024 00:35:44 -0700
Subject: [PATCH] Further refactoring

---
 flux/dreambooth.py    | 62 +++++++------------------------------------
 flux/flux/__init__.py | 50 ++++++++++++++++++++++++++++++++--
 flux/flux/lora.py     |  2 +-
 3 files changed, 59 insertions(+), 55 deletions(-)

diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index 7640418f..8cf7a995 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -87,63 +87,21 @@ class FinetuningDataset:
             yield xs[indices], t5[indices], clip[indices]
 
 
-def linear_to_lora_layers(flux, args):
-    """Swap the linear layers in the transformer blocks with LoRA layers."""
-    rank = args.lora_rank
-    all_blocks = flux.flow.double_blocks + flux.flow.single_blocks
-    all_blocks.reverse()
-    num_blocks = args.lora_blocks if args.lora_blocks > 0 else len(all_blocks)
-    for i, block in zip(range(num_blocks), all_blocks):
-        loras = []
-        for name, module in block.named_modules():
-            if isinstance(module, nn.Linear):
-                loras.append((name, LoRALinear.from_base(module, r=rank)))
-        block.update_modules(tree_unflatten(loras))
-
-
 def generate_progress_images(iteration, flux, args):
     """Generate images to monitor the progress of the finetuning."""
-
-    def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
-        with random_state(seed):
-            latents = flux.generate_latents(
-                prompt,
-                n_images=n_images,
-                num_steps=steps,
-            )
-            for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
-                mx.eval(x_t)
-
-            return x_t
-
-    def decode_latents(flux, x):
-        decoded = []
-        for i in tqdm(range(len(x))):
-            decoded.append(flux.decode(x[i : i + 1]))
-            mx.eval(decoded[-1])
-        return mx.concatenate(decoded, axis=0)
-
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
     out_file = out_dir / f"{iteration:07d}_progress.png"
     print(f"Generating {str(out_file)}", flush=True)
 
-    # Generate the latent vectors using diffusion
-    n_images = 4
-    latents = generate_latents(
-        flux,
-        n_images,
-        args.progress_prompt,
-        args.progress_steps,
-        seed=42 + mx.distributed.init().rank(),
-    )
-
-    # Reload the text encoders to reduce the memory use during training
-    flux.reload_text_encoders()
-
-    # Arrange them on a grid
+    # Generate some images and arrange them in a grid
     n_rows = 2
-    x = decode_latents(flux, latents)
+    n_images = 4
+    x = flux.generate_images(
+        args.progress_prompt,
+        n_images,
+        args.progress_steps,
+    )
     x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
     B, H, W, C = x.shape
     x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@@ -166,8 +124,8 @@ def save_adapters(iteration, flux, args):
         str(out_file),
         dict(tree_flatten(flux.flow.trainable_parameters())),
         metadata={
-            "lora_rank": args.lora_rank,
-            "lora_blocks": args.lora_blocks,
+            "lora_rank": str(args.lora_rank),
+            "lora_blocks": str(args.lora_blocks),
         },
     )
 
@@ -269,7 +227,7 @@ if __name__ == "__main__":
     flux = FluxPipeline("flux-" + args.model)
     flux.flow.freeze()
     with random_state(0x0F0F0F0F):
-        linear_to_lora_layers(flux, args)
+        flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
 
     # Report how many parameters we are training
     trainable_params = tree_reduce(
diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py
index aba1e181..30d0410a 100644
--- a/flux/flux/__init__.py
+++ b/flux/flux/__init__.py
@@ -3,8 +3,11 @@ import time
 from typing import Tuple
 
 import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
 from tqdm import tqdm
 
+from .lora import LoRALinear
 from .sampler import FluxSampler
 from .utils import (
     load_ae,
@@ -38,7 +41,7 @@ class FluxPipeline:
 
     def reload_text_encoders(self):
         self.t5 = load_t5(self.name)
-        self.clip = load_clip(name)
+        self.clip = load_clip(self.name)
 
     def tokenize(self, text):
         t5_tokens = self.t5_tokenizer.encode(text)
@@ -156,6 +159,37 @@ class FluxPipeline:
         x = self.ae.decode(x)
         return mx.clip(x + 1, 0, 2) * 0.5
 
+    def generate_images(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+        reload_text_encoders: bool = True,
+        progress: bool = True,
+    ):
+        latents = self.generate_latents(
+            text, n_images, num_steps, guidance, latent_size, seed
+        )
+        mx.eval(next(latents))
+
+        if reload_text_encoders:
+            self.reload_text_encoders()
+
+        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+            mx.eval(x_t)
+
+        images = []
+        for i in tqdm(range(len(x_t)), disable=not progress):
+            images.append(self.decode(x_t[i : i + 1]))
+            mx.eval(images[-1])
+        images = mx.concatenate(images, axis=0)
+        mx.eval(images)
+
+        return images
+
     def training_loss(
         self,
         x_0: mx.array,
@@ -171,7 +205,7 @@ class FluxPipeline:
         # Prepare the latent input
         x_0, x_ids = self._prepare_latent_images(x_0)
 
-        # Forward process (we use rf/lognorm(0, 1))
+        # Forward process
         t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
         eps = mx.random.normal(x_0.shape, dtype=self.dtype)
         x_t = self.sampler.add_noise(x_0, t, noise=eps)
@@ -189,3 +223,15 @@ class FluxPipeline:
         )
 
         return (pred + x_0 - eps).square().mean()
+
+    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+        """Swap the linear layers in the transformer blocks with LoRA layers."""
+        all_blocks = self.flow.double_blocks + self.flow.single_blocks
+        all_blocks.reverse()
+        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+        for i, block in zip(range(num_blocks), all_blocks):
+            loras = []
+            for name, module in block.named_modules():
+                if isinstance(module, nn.Linear):
+                    loras.append((name, LoRALinear.from_base(module, r=rank)))
+            block.update_modules(tree_unflatten(loras))
diff --git a/flux/flux/lora.py b/flux/flux/lora.py
index 92314440..2bf6fb69 100644
--- a/flux/flux/lora.py
+++ b/flux/flux/lora.py
@@ -23,7 +23,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def fuse(self, de_quantize: bool = False):
+    def fuse(self):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
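
Usage note (not part of the diff): a minimal sketch of how the API refactored by this patch is expected to be called, assuming the flux package from this example directory is importable; the model name and prompt below are placeholders, not values taken from the patch.

    # Sketch: LoRA setup and image generation with the refactored FluxPipeline API.
    from flux import FluxPipeline

    flux = FluxPipeline("flux-schnell")  # "flux-schnell" is an assumed model name
    flux.flow.freeze()

    # Replace nn.Linear layers in the transformer blocks with LoRA layers
    # (rank 8, all blocks when num_blocks is -1), per flux/flux/__init__.py above.
    flux.linear_to_lora_layers(rank=8, num_blocks=-1)

    # Generate a batch of images with the new helper; it evaluates the latents,
    # reloads the text encoders to reduce memory use, and decodes to pixel space.
    images = flux.generate_images(
        "a photo of sks dog",  # placeholder prompt
        n_images=4,
        num_steps=20,
    )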