From 7cffcdcaff6b1947d7648954b4ebc845184a7e91 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 3 Oct 2024 11:35:56 -0700
Subject: [PATCH] Flux lora training

---
 flux/dreambooth.py             | 212 +++++++++++++++++++++++++++++++++
 flux/flux/__init__.py          |  52 +++++++-
 flux/flux/autoencoder.py       |   9 +-
 flux/flux/lora.py              |  74 ++++++++++++
 flux/txt2image.py              | 116 ++++++++++++++++++
 stable_diffusion/dreambooth.py |   4 +-
 6 files changed, 452 insertions(+), 15 deletions(-)
 create mode 100644 flux/dreambooth.py
 create mode 100644 flux/flux/lora.py
 create mode 100644 flux/txt2image.py

diff --git a/flux/dreambooth.py b/flux/dreambooth.py
new file mode 100644
index 00000000..fe36a82b
--- /dev/null
+++ b/flux/dreambooth.py
@@ -0,0 +1,212 @@
+import argparse
+import time
+from functools import partial
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+import numpy as np
+from mlx.utils import tree_reduce, tree_unflatten
+from PIL import Image
+from tqdm import tqdm
+
+from flux import FluxPipeline
+from flux.lora import LoRALinear
+
+
+def linear_to_lora_layers(flux):
+    lora_layers = []
+    for name, mod in flux.flow.named_modules():
+        if ".img_attn" not in name:
+            continue
+        if ".qkv" in name or ".proj" in name:
+            lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
+    flux.flow.update_modules(tree_unflatten(lora_layers))
+
+
+def extract_latent_vectors(flux, image_folder):
+    flux.ae.eval()
+    latents = []
+    for image in tqdm(Path(image_folder).iterdir()):
+        img = Image.open(image)
+        img = mx.array(np.array(img))
+        img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
+        x_0 = flux.ae.encode(img[None])
+        x_0 = x_0.astype(flux.dtype)
+        mx.eval(x_0)
+        latents.append(x_0)
+    return mx.concatenate(latents)
+
+
+def decode_latents(flux, x):
+    decoded = []
+    for i in tqdm(range(len(x))):
+        decoded.append(flux.decode(x[i : i + 1]))
+        mx.eval(decoded[-1])
+    return mx.concatenate(decoded, axis=0)
+
+
+def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
+    latents = flux.generate_latents(
+        prompt,
+        n_images=n_images,
+        num_steps=steps,
+        seed=seed,
+    )
+    for x_t in tqdm(latents, total=steps, leave=leave):
+        mx.eval(x_t)
+
+    return x_t
+
+
+def generate_progress_images(iteration, flux, args):
+    out_dir = Path(args.output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    out_file = out_dir / f"out_{iteration:03d}.png"
+    print(f"Generating {str(out_file)}")
+    # Generate the latent vectors using diffusion
+    n_images = 4
+    latents = generate_latents(
+        flux,
+        n_images,
+        args.progress_prompt,
+        args.progress_steps,
+        seed=42,
+    )
+
+    # Arrange them on a grid
+    n_rows = 2
+    x = decode_latents(flux, latents)
+    x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
+    B, H, W, C = x.shape
+    x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+    x = x.reshape(n_rows * H, B // n_rows * W, C)
+    x = (x * 255).astype(mx.uint8)
+
+    # Save them to disc
+    im = Image.fromarray(np.array(x))
+    im.save(out_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Finetune Flux to generate images with a specific subject"
+    )
+
+    parser.add_argument(
+        "--model",
+        default="dev",
+        choices=[
+            "dev",
+            "schnell",
+        ],
+        help="Which flux model to train",
+    )
+    parser.add_argument(
+        "--iterations",
+        type=int,
+        default=400,
+        help="How many iterations to train for",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size to use when training the flux model",
+    )
diffusion model", + ) + parser.add_argument( + "--progress-prompt", + help="Use this prompt when generating images for evaluation", + ) + parser.add_argument( + "--progress-steps", + type=int, + default=50, + help="Use this many steps when generating images for evaluation", + ) + parser.add_argument( + "--progress-every", + type=int, + default=50, + help="Generate images every PROGRESS_EVERY steps", + ) + parser.add_argument( + "--checkpoint-every", + type=int, + default=50, + help="Save the model every CHECKPOINT_EVERY steps", + ) + parser.add_argument( + "--learning-rate", type=float, default="1e-6", help="Learning rate for training" + ) + parser.add_argument( + "--output-dir", default="mlx_output", help="Folder to save the checkpoints in" + ) + + parser.add_argument("prompt") + parser.add_argument("image_folder") + + args = parser.parse_args() + + args.progress_prompt = args.progress_prompt or args.prompt + + flux = FluxPipeline("flux-" + args.model) + flux.ensure_models_are_loaded() + flux.flow.freeze() + linear_to_lora_layers(flux) + + trainable_params = tree_reduce( + lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 + ) + print(f"Training {trainable_params / 1024**2:.3f}M parameters") + + optimizer = optim.Adam(learning_rate=args.learning_rate) + state = [flux.flow.state, optimizer.state, mx.random.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(t5_tokens, clip_tokens, x, guidance): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + t5_tokens, clip_tokens, x, guidance + ) + optimizer.update(flux.flow, grads) + + return loss + + print("Encoding training images to latent space") + x = extract_latent_vectors(flux, args.image_folder) + t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x)) + guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype) + + # An initial generation to compare + generate_progress_images(0, flux, args) + + losses = [] + tic = time.time() + for i in range(args.iterations): + indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype( + mx.uint32 + ) + loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance) + mx.eval(loss, state) + losses.append(loss.item()) + + if (i + 1) % 10 == 0: + toc = time.time() + peak_mem = mx.metal.get_peak_memory() / 1024**3 + print( + f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"It/s: {10 / (toc - tic):.3f} " + f"Peak mem: {peak_mem:.3f} GB" + ) + + if (i + 1) % args.progress_every == 0: + generate_progress_images(i + 1, flux, args) + + if (i + 1) % args.checkpoint_every == 0: + pass + # save_checkpoints(i + 1, sd, args) + + if (i + 1) % 10 == 0: + losses = [] + tic = time.time() diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 0d5f1b70..b4f63d99 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -36,6 +36,11 @@ class FluxPipeline: self.t5.parameters(), ) + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + def _prepare_latent_images(self, x): b, h, w, c = x.shape @@ -56,16 +61,14 @@ class FluxPipeline: return x, x_ids - def _prepare_conditioning(self, n_images, text): + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): # Prepare the text features - t5_tokens = self.t5_tokenizer.encode(text) txt = self.t5(t5_tokens) if len(txt) == 1 and n_images > 1: txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) # Prepare the 
         # Prepare the clip text features
-        clip_tokens = self.clip_tokenizer.encode(text)
         vec = self.clip(clip_tokens).pooled_output
         if len(vec) == 1 and n_images > 1:
             vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
@@ -131,8 +134,13 @@ class FluxPipeline:
         x_T, x_ids = self._prepare_latent_images(x_T)
 
         # Get the conditioning
-        txt, txt_ids, vec = self._prepare_conditioning(n_images, text)
+        t5_tokens, clip_tokens = self.tokenize(text)
+        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
 
+        # Yield the conditioning for controlled evaluation by the caller
+        yield (x_T, x_ids, txt, txt_ids, vec)
+
+        # Yield the latent sequences from the denoising loop
         yield from self._denoising_loop(
             x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
         )
@@ -142,6 +150,38 @@ class FluxPipeline:
         x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
         x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
         x = self.ae.decode(x)
-        x = (mx.clip(x + 1, 0, 2) * 127.5).astype(mx.uint8)
+        return mx.clip(x + 1, 0, 2) * 0.5
 
-        return x
+    def training_loss(
+        self,
+        t5_tokens: mx.array,
+        clip_tokens: mx.array,
+        x_0: mx.array,
+        guidance: mx.array,
+    ):
+        # Get the text conditioning
+        txt = self.t5(t5_tokens)
+        txt_ids = mx.zeros(t5_tokens.shape + (3,), dtype=mx.int32)
+        vec = self.clip(clip_tokens).pooled_output
+
+        # Prepare the latent input
+        x_0, x_ids = self._prepare_latent_images(x_0)
+
+        # Forward process (we use rf/lognorm(0, 1))
+        t = mx.sigmoid(mx.random.normal(shape=(len(x_0),), dtype=self.dtype))
+        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+        x_t = self.sampler.add_noise(x_0, t, noise=eps)
+        x_t = mx.stop_gradient(x_t)
+
+        # Do the denoising
+        pred = self.flow(
+            img=x_t,
+            img_ids=x_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t,
+            guidance=guidance,
+        )
+
+        return (pred - (eps - x_0)).square().mean()
diff --git a/flux/flux/autoencoder.py b/flux/flux/autoencoder.py
index 9d470cb9..2becfd26 100644
--- a/flux/flux/autoencoder.py
+++ b/flux/flux/autoencoder.py
@@ -296,14 +296,9 @@ class Decoder(nn.Module):
 
 
 class DiagonalGaussian(nn.Module):
-    def __init__(self, sample: bool = True, chunk_dim: int = 1):
-        super().__init__()
-        self.sample = sample
-        self.chunk_dim = chunk_dim
-
     def __call__(self, z: mx.array):
-        mean, logvar = mx.split(z, 2, axis=self.chunk_dim)
-        if self.sample:
+        mean, logvar = mx.split(z, 2, axis=-1)
+        if self.training:
             std = mx.exp(0.5 * logvar)
             eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
             return mean + std * eps
diff --git a/flux/flux/lora.py b/flux/flux/lora.py
new file mode 100644
index 00000000..785e5446
--- /dev/null
+++ b/flux/flux/lora.py
@@ -0,0 +1,74 @@
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class LoRALinear(nn.Module):
+    @staticmethod
+    def from_base(
+        linear: nn.Linear,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        output_dims, input_dims = linear.weight.shape
+        lora_lin = LoRALinear(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        lora_lin.linear = linear
+        return lora_lin
+
+    def fuse(self, de_quantize: bool = False):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+        dtype = weight.dtype
+
+        output_dims, input_dims = weight.shape
+        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        lora_b = (self.scale * self.lora_b.T).astype(dtype)
+        lora_a = self.lora_a.T.astype(dtype)
+        fused_linear.weight = weight + lora_b @ lora_a
+        if bias:
+            fused_linear.bias = linear.bias
+
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, output_dims))
+
+    def __call__(self, x):
+        y = self.linear(x)
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+        return y + (self.scale * z).astype(x.dtype)
diff --git a/flux/txt2image.py b/flux/txt2image.py
new file mode 100644
index 00000000..0c5ab68d
--- /dev/null
+++ b/flux/txt2image.py
@@ -0,0 +1,116 @@
+import argparse
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from flux import FluxPipeline
+
+
+def to_latent_size(image_size):
+    h, w = image_size
+    h = ((h + 15) // 16) * 16
+    w = ((w + 15) // 16) * 16
+
+    if (h, w) != image_size:
+        print(
+            "Warning: The image dimensions need to be divisible by 16px. "
+            f"Changing size to {h}x{w}."
+        )
+
+    return (h // 8, w // 8)
+
+
+def quantization_predicate(name, m):
+    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Generate images from a textual prompt using FLUX"
+    )
+    parser.add_argument("prompt")
+    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
+    parser.add_argument("--n_images", type=int, default=4)
+    parser.add_argument(
+        "--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
+    )
+    parser.add_argument("--steps", type=int)
+    parser.add_argument("--guidance", type=float, default=4.0)
+    parser.add_argument("--n_rows", type=int, default=1)
+    parser.add_argument("--decoding_batch_size", type=int, default=1)
+    parser.add_argument("--quantize", "-q", action="store_true")
+    parser.add_argument("--preload-models", action="store_true")
+    parser.add_argument("--output", default="out.png")
+    parser.add_argument("--seed", type=int)
+    parser.add_argument("--verbose", "-v", action="store_true")
+    args = parser.parse_args()
+
+    # Load the models
+    flux = FluxPipeline("flux-" + args.model)
+    args.steps = args.steps or (50 if args.model == "dev" else 2)
+
+    if args.quantize:
+        nn.quantize(flux.flow, class_predicate=quantization_predicate)
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
+    if args.preload_models:
+        flux.ensure_models_are_loaded()
+
+    # Make the generator
+    latent_size = to_latent_size(args.image_size)
+    latents = flux.generate_latents(
+        args.prompt,
+        n_images=args.n_images,
+        num_steps=args.steps,
+        latent_size=latent_size,
+        guidance=args.guidance,
+        seed=args.seed,
+    )
+
+    # First we get and eval the conditioning
+    conditioning = next(latents)
+    mx.eval(conditioning)
+    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
+
+    # The following is not necessary but it may help in memory-constrained
+    # systems by reusing the memory kept by the text encoders.
+    del flux.t5
+    del flux.clip
+
+    # Actual denoising loop
+    for x_t in tqdm(latents, total=args.steps):
+        mx.eval(x_t)
+
+    # The following is not necessary but it may help in memory-constrained
+    # systems by reusing the memory kept by the flow transformer.
+    del flux.flow
+    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
+
+    # Decode them into images
+    decoded = []
+    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+        decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
+        mx.eval(decoded[-1])
+    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
+
+    # Arrange them on a grid
+    x = mx.concatenate(decoded, axis=0)
+    x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
+    B, H, W, C = x.shape
+    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
+    x = (x * 255).astype(mx.uint8)
+
+    # Save them to disc
+    im = Image.fromarray(np.array(x))
+    im.save(args.output)
+
+    # Report the peak memory used during generation
+    if args.verbose:
+        print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
+        print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
+        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
diff --git a/stable_diffusion/dreambooth.py b/stable_diffusion/dreambooth.py
index 5d6922c2..7be12aa1 100644
--- a/stable_diffusion/dreambooth.py
+++ b/stable_diffusion/dreambooth.py
@@ -151,7 +151,7 @@ if __name__ == "__main__":
         help="Save the model every CHECKPOINT_EVERY steps",
     )
     parser.add_argument(
-        "--learning_rate", type=float, default="1e-6", help="Learning rate for training"
+        "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
    )
     parser.add_argument(
         "--predict-x0",
@@ -201,7 +201,7 @@ if __name__ == "__main__":
     sd = StableDiffusion(args.model)
     sd.ensure_models_are_loaded()
 
-    optimizer = optim.Adam(learning_rate=1e-6)
+    optimizer = optim.Adam(learning_rate=args.learning_rate)
 
     def loss_fn(params, text, x, weights):
         sd.unet.update(params["unet"])
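
Usage note (not part of the patch): checkpoint saving in flux/dreambooth.py is still stubbed out (the save_checkpoints call is commented), so one way to sanity-check a finetuned run in the same process is to fuse the LoRA updates back into the flow transformer and then sample as usual. The sketch below is a hypothetical helper, assembled from the same named_modules / update_modules / tree_unflatten calls that linear_to_lora_layers already uses, plus the LoRALinear.fuse method added in flux/flux/lora.py; fuse_lora_layers itself is not defined anywhere in the patch.

    from mlx.utils import tree_unflatten

    from flux.lora import LoRALinear


    def fuse_lora_layers(flux):
        # Replace every LoRALinear in the flow transformer with a plain
        # nn.Linear whose weight already contains the scaled low-rank update.
        fused = [
            (name, module.fuse())
            for name, module in flux.flow.named_modules()
            if isinstance(module, LoRALinear)
        ]
        flux.flow.update_modules(tree_unflatten(fused))


    # After the training loop in dreambooth.py one could then run, e.g.:
    # fuse_lora_layers(flux)
    # generate_progress_images(args.iterations, flux, args)

Fusing folds the low-rank update into the base weights, so inference afterwards runs through plain linear layers with no extra per-call matmuls, at the cost of no longer being able to toggle or rescale the adapter.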