From f491d473a332fc3ff9daf74be2bd154c0f9231b5 Mon Sep 17 00:00:00 2001 From: madroid Date: Wed, 16 Oct 2024 01:37:45 +0800 Subject: [PATCH] FLUX: Optimize dataset loading logic (#1038) --- flux/README.md | 47 ++++---- flux/dreambooth.py | 121 +++------------------ flux/flux/__init__.py | 239 +--------------------------------------- flux/flux/datasets.py | 75 +++++++++++++ flux/flux/flux.py | 246 ++++++++++++++++++++++++++++++++++++++++++ flux/flux/trainer.py | 98 +++++++++++++++++ 6 files changed, 461 insertions(+), 365 deletions(-) create mode 100644 flux/flux/datasets.py create mode 100644 flux/flux/flux.py create mode 100644 flux/flux/trainer.py diff --git a/flux/README.md b/flux/README.md index 0496c71b..1a17e386 100644 --- a/flux/README.md +++ b/flux/README.md @@ -21,8 +21,9 @@ The dependencies are minimal, namely: - `huggingface-hub` to download the checkpoints. - `regex` for the tokenization -- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script +- `tqdm`, `PIL`, and `numpy` for the scripts - `sentencepiece` for the T5 tokenizer +- `datasets` for using an HF dataset directly You can install all of the above with the `requirements.txt` as follows: @@ -118,17 +119,12 @@ Finetuning The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell but ymmv) on a provided image dataset. The dataset folder must have an -`index.json` file with the following format: +`train.jsonl` file with the following format: -```json -{ - "data": [ - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - ... - ] -} +```jsonl +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +... ``` The training script by default trains for 600 iterations with a batch size of @@ -150,19 +146,15 @@ The training images are the following 5 images [^2]: ![dog6](static/dog6.png) -We start by making the following `index.json` file and placing it in the same +We start by making the following `train.jsonl` file and placing it in the same folder as the images. -```json -{ - "data": [ - {"image": "00.jpg", "text": "A photo of sks dog"}, - {"image": "01.jpg", "text": "A photo of sks dog"}, - {"image": "02.jpg", "text": "A photo of sks dog"}, - {"image": "03.jpg", "text": "A photo of sks dog"}, - {"image": "04.jpg", "text": "A photo of sks dog"} - ] -} +```jsonl +{"image": "00.jpg", "prompt": "A photo of sks dog"} +{"image": "01.jpg", "prompt": "A photo of sks dog"} +{"image": "02.jpg", "prompt": "A photo of sks dog"} +{"image": "03.jpg", "prompt": "A photo of sks dog"} +{"image": "04.jpg", "prompt": "A photo of sks dog"} ``` Subsequently we finetune FLUX using the following command: @@ -175,6 +167,17 @@ python dreambooth.py \ path/to/dreambooth/dataset/dog6 ``` + +Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning. + +```shell +python dreambooth.py \ + --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \ + --progress-every 600 --iterations 1200 --learning-rate 0.0001 \ + --lora-rank 4 --grad-accumulate 8 \ + mlx-community/dreambooth-dog6 +``` + The training requires approximately 50GB of RAM and on an M2 Ultra it takes a bit more than 1 hour. diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a4dbb08..48dcad47 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -1,7 +1,6 @@ # Copyright © 2024 Apple Inc. import argparse -import json import time from functools import partial from pathlib import Path @@ -13,105 +12,8 @@ import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from tqdm import tqdm -from flux import FluxPipeline - - -class FinetuningDataset: - def __init__(self, flux, args): - self.args = args - self.flux = flux - self.dataset_base = Path(args.dataset) - dataset_index = self.dataset_base / "index.json" - if not dataset_index.exists(): - raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset") - with open(dataset_index, "r") as f: - self.index = json.load(f) - - self.latents = [] - self.t5_features = [] - self.clip_features = [] - - def _random_crop_resize(self, img): - resolution = self.args.resolution - width, height = img.size - - a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() - - # Random crop the input image between 0.8 to 1.0 of its original dimensions - crop_size = ( - max((0.8 + 0.2 * a) * width, resolution[0]), - max((0.8 + 0.2 * a) * height, resolution[1]), - ) - pan = (width - crop_size[0], height - crop_size[1]) - img = img.crop( - ( - pan[0] * b, - pan[1] * c, - crop_size[0] + pan[0] * b, - crop_size[1] + pan[1] * c, - ) - ) - - # Fit the largest rectangle with the ratio of resolution in the image - # rectangle. - width, height = crop_size - ratio = resolution[0] / resolution[1] - r1 = (height * ratio, height) - r2 = (width, width / ratio) - r = r1 if r1[0] <= width else r2 - img = img.crop( - ( - (width - r[0]) / 2, - (height - r[1]) / 2, - (width + r[0]) / 2, - (height + r[1]) / 2, - ) - ) - - # Finally resize the image to resolution - img = img.resize(resolution, Image.LANCZOS) - - return mx.array(np.array(img)) - - def encode_images(self): - """Encode the images in the latent space to prepare for training.""" - self.flux.ae.eval() - for sample in tqdm(self.index["data"]): - input_img = Image.open(self.dataset_base / sample["image"]) - for i in range(self.args.num_augmentations): - img = self._random_crop_resize(input_img) - img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 - x_0 = self.flux.ae.encode(img[None]) - x_0 = x_0.astype(self.flux.dtype) - mx.eval(x_0) - self.latents.append(x_0) - - def encode_prompts(self): - """Pre-encode the prompts so that we don't recompute them during - training (doesn't allow finetuning the text encoders).""" - for sample in tqdm(self.index["data"]): - t5_tok, clip_tok = self.flux.tokenize([sample["text"]]) - t5_feat = self.flux.t5(t5_tok) - clip_feat = self.flux.clip(clip_tok).pooled_output - mx.eval(t5_feat, clip_feat) - self.t5_features.append(t5_feat) - self.clip_features.append(clip_feat) - - def iterate(self, batch_size): - xs = mx.concatenate(self.latents) - t5 = mx.concatenate(self.t5_features) - clip = mx.concatenate(self.clip_features) - mx.eval(xs, t5, clip) - n_aug = self.args.num_augmentations - while True: - x_indices = mx.random.permutation(len(self.latents)) - c_indices = x_indices // n_aug - for i in range(0, len(self.latents), batch_size): - x_i = x_indices[i : i + batch_size] - c_i = c_indices[i : i + batch_size] - yield xs[x_i], t5[c_i], clip[c_i] +from flux import FluxPipeline, Trainer, load_dataset def generate_progress_images(iteration, flux, args): @@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args): ) -if __name__ == "__main__": +def setup_arg_parser(): + """Set up and return the argument parser.""" parser = argparse.ArgumentParser( description="Finetune Flux to generate images with a specific subject" ) @@ -247,7 +150,11 @@ if __name__ == "__main__": ) parser.add_argument("dataset") + return parser + +if __name__ == "__main__": + parser = setup_arg_parser() args = parser.parse_args() # Load the model and set it up for LoRA training. We use the same random @@ -267,7 +174,7 @@ if __name__ == "__main__": trainable_params = tree_reduce( lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 ) - print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True) + print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True) # Set up the optimizer and training steps. The steps are a bit verbose to # support gradient accumulation together with compilation. @@ -340,10 +247,10 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - print("Create the training dataset.", flush=True) - dataset = FinetuningDataset(flux, args) - dataset.encode_images() - dataset.encode_prompts() + dataset = load_dataset(args.dataset) + trainer = Trainer(flux, dataset, args) + trainer.encode_dataset() + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) # An initial generation to compare @@ -352,7 +259,7 @@ if __name__ == "__main__": grads = None losses = [] tic = time.time() - for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): + for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) mx.eval(loss, grads, state) losses.append(loss.item()) @@ -361,7 +268,7 @@ if __name__ == "__main__": toc = time.time() peak_mem = mx.metal.get_peak_memory() / 1024**3 print( - f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " f"Peak mem: {peak_mem:.3f} GB", flush=True, diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 8d39d605..b1122d75 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -1,16 +1,10 @@ # Copyright © 2024 Apple Inc. -import math -import time -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten -from tqdm import tqdm - +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline from .lora import LoRALinear from .sampler import FluxSampler +from .trainer import Trainer from .utils import ( load_ae, load_clip, @@ -19,230 +13,3 @@ from .utils import ( load_t5, load_t5_tokenizer, ) - - -class FluxPipeline: - def __init__(self, name: str, t5_padding: bool = True): - self.dtype = mx.bfloat16 - self.name = name - self.t5_padding = t5_padding - - self.ae = load_ae(name) - self.flow = load_flow_model(name) - self.clip = load_clip(name) - self.clip_tokenizer = load_clip_tokenizer(name) - self.t5 = load_t5(name) - self.t5_tokenizer = load_t5_tokenizer(name) - self.sampler = FluxSampler(name) - - def ensure_models_are_loaded(self): - mx.eval( - self.ae.parameters(), - self.flow.parameters(), - self.clip.parameters(), - self.t5.parameters(), - ) - - def reload_text_encoders(self): - self.t5 = load_t5(self.name) - self.clip = load_clip(self.name) - - def tokenize(self, text): - t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) - clip_tokens = self.clip_tokenizer.encode(text) - return t5_tokens, clip_tokens - - def _prepare_latent_images(self, x): - b, h, w, c = x.shape - - # Pack the latent image to 2x2 patches - x = x.reshape(b, h // 2, 2, w // 2, 2, c) - x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) - - # Create positions ids used to positionally encode each patch. Due to - # the way RoPE works, this results in an interesting positional - # encoding where parts of the feature are holding different positional - # information. Namely, the first part holds information independent of - # the spatial position (hence 0s), the 2nd part holds vertical spatial - # information and the last one horizontal. - i = mx.zeros((h // 2, w // 2), dtype=mx.int32) - j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") - x_ids = mx.stack([i, j, k], axis=-1) - x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) - - return x, x_ids - - def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): - # Prepare the text features - txt = self.t5(t5_tokens) - if len(txt) == 1 and n_images > 1: - txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) - txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) - - # Prepare the clip text features - vec = self.clip(clip_tokens).pooled_output - if len(vec) == 1 and n_images > 1: - vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) - - return txt, txt_ids, vec - - def _denoising_loop( - self, - x_t, - x_ids, - txt, - txt_ids, - vec, - num_steps: int = 35, - guidance: float = 4.0, - start: float = 1, - stop: float = 0, - ): - B = len(x_t) - - def scalar(x): - return mx.full((B,), x, dtype=self.dtype) - - guidance = scalar(guidance) - timesteps = self.sampler.timesteps( - num_steps, - x_t.shape[1], - start=start, - stop=stop, - ) - for i in range(num_steps): - t = timesteps[i] - t_prev = timesteps[i + 1] - - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=scalar(t), - guidance=guidance, - ) - x_t = self.sampler.step(pred, x_t, t, t_prev) - - yield x_t - - def generate_latents( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - ): - # Set the PRNG state - if seed is not None: - mx.random.seed(seed) - - # Create the latent variables - x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) - x_T, x_ids = self._prepare_latent_images(x_T) - - # Get the conditioning - t5_tokens, clip_tokens = self.tokenize(text) - txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) - - # Yield the conditioning for controlled evaluation by the caller - yield (x_T, x_ids, txt, txt_ids, vec) - - # Yield the latent sequences from the denoising loop - yield from self._denoising_loop( - x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance - ) - - def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): - h, w = latent_size - x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) - x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) - x = self.ae.decode(x) - return mx.clip(x + 1, 0, 2) * 0.5 - - def generate_images( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - reload_text_encoders: bool = True, - progress: bool = True, - ): - latents = self.generate_latents( - text, n_images, num_steps, guidance, latent_size, seed - ) - mx.eval(next(latents)) - - if reload_text_encoders: - self.reload_text_encoders() - - for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): - mx.eval(x_t) - - images = [] - for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i : i + 1])) - mx.eval(images[-1]) - images = mx.concatenate(images, axis=0) - mx.eval(images) - - return images - - def training_loss( - self, - x_0: mx.array, - t5_features: mx.array, - clip_features: mx.array, - guidance: mx.array, - ): - # Get the text conditioning - txt = t5_features - txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) - vec = clip_features - - # Prepare the latent input - x_0, x_ids = self._prepare_latent_images(x_0) - - # Forward process - t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=self.dtype) - x_t = self.sampler.add_noise(x_0, t, noise=eps) - x_t = mx.stop_gradient(x_t) - - # Do the denoising - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t, - guidance=guidance, - ) - - return (pred + x_0 - eps).square().mean() - - def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): - """Swap the linear layers in the transformer blocks with LoRA layers.""" - all_blocks = self.flow.double_blocks + self.flow.single_blocks - all_blocks.reverse() - num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) - for i, block in zip(range(num_blocks), all_blocks): - loras = [] - for name, module in block.named_modules(): - if isinstance(module, nn.Linear): - loras.append((name, LoRALinear.from_base(module, r=rank))) - block.update_modules(tree_unflatten(loras)) - - def fuse_lora_layers(self): - fused_layers = [] - for name, module in self.flow.named_modules(): - if isinstance(module, LoRALinear): - fused_layers.append((name, module.fuse())) - self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py new file mode 100644 index 00000000..d31a09f1 --- /dev/null +++ b/flux/flux/datasets.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path + +from PIL import Image + + +class Dataset: + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + +class LocalDataset(Dataset): + prompt_key = "prompt" + + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) + with open(data_file, "r") as fid: + self._data = [json.loads(l) for l in fid] + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item[self.prompt_key] + + +class LegacyDataset(LocalDataset): + prompt_key = "text" + + def __init__(self, dataset: str): + self.dataset_base = Path(dataset) + with open(self.dataset_base / "index.json") as f: + self._data = json.load(f)["data"] + + +class HuggingFaceDataset(Dataset): + + def __init__(self, dataset: str): + from datasets import load_dataset as hf_load_dataset + + self._df = hf_load_dataset(dataset)["train"] + + def __len__(self): + return len(self._df) + + def __getitem__(self, index: int): + item = self._df[index] + return item["image"], item["prompt"] + + +def load_dataset(dataset: str): + dataset_base = Path(dataset) + data_file = dataset_base / "train.jsonl" + legacy_file = dataset_base / "index.json" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + dataset = LocalDataset(dataset, data_file) + elif legacy_file.exists(): + print(f"Load the local dataset {legacy_file} .") + print() + print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") + print(" See the README for details.") + print(flush=True) + dataset = LegacyDataset(dataset) + else: + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) + + return dataset diff --git a/flux/flux/flux.py b/flux/flux/flux.py new file mode 100644 index 00000000..3fd044ac --- /dev/null +++ b/flux/flux/flux.py @@ -0,0 +1,246 @@ +# Copyright © 2024 Apple Inc. + +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. + i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py new file mode 100644 index 00000000..40a126e8 --- /dev/null +++ b/flux/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * b) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * c, + pan[1] * d, + crop_size[0] + pan[0] * c, + crop_size[1] + pan[1] * d, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i]