From bb8436a4419b7c91ebda9338c9a99b5f7b2c617c Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 8 Oct 2024 01:09:24 -0700
Subject: [PATCH] Update dataset

---
 flux/dreambooth.py      | 204 ++++++++++++++++++++++++++++------------
 flux/flux/__init__.py   |  14 +--
 flux/flux/tokenizers.py |  20 +++-
 flux/flux/utils.py      |   2 +-
 4 files changed, 168 insertions(+), 72 deletions(-)

diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index e9995aea..e6c895ba 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -1,5 +1,7 @@
 import argparse
+import json
 import time
+from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
 
@@ -16,6 +18,106 @@ from flux import FluxPipeline
 from flux.lora import LoRALinear
 
 
+@contextmanager
+def random_state(seed=None):
+    s = mx.random.state[0]
+    try:
+        if seed is not None:
+            mx.random.seed(seed)
+        yield
+    finally:
+        mx.random.state[0] = s
+
+
+class FinetuningDataset:
+    def __init__(self, flux, args):
+        self.args = args
+        self.flux = flux
+        self.dataset_base = Path(args.dataset)
+        dataset_index = self.dataset_base / "index.json"
+        if not dataset_index.exists():
+            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
+        with open(dataset_index, "r") as f:
+            self.index = json.load(f)
+
+        self.latents = []
+        self.t5_features = []
+        self.clip_features = []
+
+    def encode_images(self):
+        """Encode the images in the latent space to prepare for training."""
+        self.flux.ae.eval()
+        for sample in tqdm(self.index["data"]):
+            img = Image.open(self.dataset_base / sample["image"])
+            img = mx.array(np.array(img))
+            img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+            x_0 = self.flux.ae.encode(img[None])
+            x_0 = x_0.astype(self.flux.dtype)
+            mx.eval(x_0)
+            self.latents.append(x_0)
+
+    def encode_prompts(self):
+        """Pre-encode the prompts so that we don't recompute them during
+        training (this also means the text encoders cannot be finetuned)."""
+        for sample in tqdm(self.index["data"]):
+            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
+            t5_feat = self.flux.t5(t5_tok)
+            clip_feat = self.flux.clip(clip_tok).pooled_output
+            mx.eval(t5_feat, clip_feat)
+            self.t5_features.append(t5_feat)
+            self.clip_features.append(clip_feat)
+
+    def generate_prior_preservation(self):
+        """Generate some images and mix them with the training images to avoid
+        overfitting to the dataset."""
+
+        prior_preservation = self.index.get("prior_preservation", None)
+        if not prior_preservation:
+            return
+
+        # Select a random set of prompts from the available ones
+        prior_prompts = mx.random.randint(
+            low=0,
+            high=len(prior_preservation["prompts"]),
+            shape=(prior_preservation["n_images"],),
+        ).tolist()
+
+        # For each prompt
+        for prompt_idx in tqdm(prior_prompts):
+            # Create the generator
+            latents = self.flux.generate_latents(
+                prior_preservation["prompts"][prompt_idx],
+                num_steps=prior_preservation.get(
+                    "num_steps", 2 if "schnell" in self.flux.name else 35
+                ),
+            )
+
+            # Extract the t5 and clip features
+            conditioning = next(latents)
+            mx.eval(conditioning)
+            t5_feat = conditioning[2]
+            clip_feat = conditioning[4]
+            del conditioning
+
+            # Do the denoising
+            for x_t in latents:
+                mx.eval(x_t)
+
+            # Append everything in the data lists
+            self.latents.append(x_t)
+            self.t5_features.append(t5_feat)
+            self.clip_features.append(clip_feat)
+
+    def iterate(self, batch_size):
+        while True:
+            indices = mx.random.randint(0, len(self.latents), (batch_size,)).tolist()
+            x = mx.concatenate([self.latents[i] for i in indices])
+            t5 = mx.concatenate([self.t5_features[i] for i in indices])
+            clip = mx.concatenate([self.clip_features[i] for i in indices])
+            mx.eval(x, t5, clip)
+            yield x, t5, clip
+
+
 def linear_to_lora_layers(flux, args):
     lora_layers = []
     rank = args.lora_rank
@@ -27,20 +129,6 @@ def linear_to_lora_layers(flux, args):
     flux.flow.update_modules(tree_unflatten(lora_layers))
 
 
-def extract_latent_vectors(flux, image_folder):
-    flux.ae.eval()
-    latents = []
-    for image in tqdm(Path(image_folder).iterdir()):
-        img = Image.open(image)
-        img = mx.array(np.array(img))
-        img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
-        x_0 = flux.ae.encode(img[None])
-        x_0 = x_0.astype(flux.dtype)
-        mx.eval(x_0)
-        latents.append(x_0)
-    return mx.concatenate(latents)
-
-
 def decode_latents(flux, x):
     decoded = []
     for i in tqdm(range(len(x))):
@@ -50,32 +138,23 @@ def decode_latents(flux, x):
 
 def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
-    latents = flux.generate_latents(
-        prompt,
-        n_images=n_images,
-        num_steps=steps,
-        seed=seed,
-    )
-    for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
-        mx.eval(x_t)
+    with random_state(seed):
+        latents = flux.generate_latents(
+            prompt,
+            n_images=n_images,
+            num_steps=steps,
+        )
+        for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
+            mx.eval(x_t)
 
     return x_t
 
 
-def iterate_batches(t5_tokens, clip_tokens, x, batch_size):
-    while True:
-        indices = mx.random.randint(0, len(x), (batch_size,))
-        t5_i = t5_tokens[indices]
-        clip_i = clip_tokens[indices]
-        x_i = x[indices]
-        yield t5_i, clip_i, x_i
-
-
 def generate_progress_images(iteration, flux, args):
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
     out_file = out_dir / f"out_{iteration:03d}.png"
-    print(f"Generating {str(out_file)}")
+    print(f"Generating {str(out_file)}", flush=True)
 
     # Generate the latent vectors using diffusion
     n_images = 4
     latents = generate_latents(
@@ -118,7 +197,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--iterations",
         type=int,
-        default=400,
+        default=1000,
         help="How many iterations to train for",
     )
     parser.add_argument(
@@ -129,6 +208,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--progress-prompt",
+        required=True,
         help="Use this prompt when generating images for evaluation",
     )
     parser.add_argument(
@@ -156,7 +236,7 @@ if __name__ == "__main__":
         "--warmup-steps", type=int, default=100, help="Learning rate warmup"
     )
     parser.add_argument(
-        "--learning-rate", type=float, default="1e-4", help="Learning rate for training"
+        "--learning-rate", type=float, default="1e-5", help="Learning rate for training"
     )
     parser.add_argument(
         "--grad-accumulate",
@@ -168,22 +248,24 @@ if __name__ == "__main__":
         "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
     )
 
-    parser.add_argument("prompt")
-    parser.add_argument("image_folder")
+    parser.add_argument("dataset")
 
     args = parser.parse_args()
 
-    args.progress_prompt = args.progress_prompt or args.prompt
+    # Initialize the seed, but make it different per worker if we are in a
+    # distributed setting.
+    mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
 
     flux = FluxPipeline("flux-" + args.model)
     flux.ensure_models_are_loaded()
     flux.flow.freeze()
-    linear_to_lora_layers(flux, args)
+    with random_state(0x0F0F0F0F):
+        linear_to_lora_layers(flux, args)
 
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
     )
-    print(f"Training {trainable_params / 1024**2:.3f}M parameters")
+    print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
 
     warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
     cosine = optim.cosine_decay(
@@ -194,9 +276,9 @@ if __name__ == "__main__":
     state = [flux.flow.state, optimizer.state, mx.random.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def single_step(t5_tokens, clip_tokens, x, guidance):
+    def single_step(x, t5_feat, clip_feat, guidance):
         loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
-            t5_tokens, clip_tokens, x, guidance
+            x, t5_feat, clip_feat, guidance
         )
         grads = average_gradients(grads)
         optimizer.update(flux.flow, grads)
@@ -204,25 +286,23 @@ if __name__ == "__main__":
         return loss
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance):
+    def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
         return nn.value_and_grad(flux.flow, flux.training_loss)(
-            t5_tokens, clip_tokens, x, guidance
+            x, t5_feat, clip_feat, guidance
         )
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def compute_loss_and_accumulate_grads(
-        t5_tokens, clip_tokens, x, guidance, prev_grads
-    ):
+    def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
         loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
-            t5_tokens, clip_tokens, x, guidance
+            x, t5_feat, clip_feat, guidance
         )
         grads = tree_map(lambda a, b: a + b, prev_grads, grads)
         return loss, grads
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def grad_accumulate_and_step(t5_tokens, clip_tokens, x, guidance, prev_grads):
+    def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
         loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
-            t5_tokens, clip_tokens, x, guidance
+            x, t5_feat, clip_feat, guidance
         )
         grads = tree_map(lambda a, b: a + b, prev_grads, grads)
         grads = average_gradients(grads)
@@ -230,28 +310,30 @@ if __name__ == "__main__":
         return loss
 
-    def step(t5_tokens, clip_tokens, x, guidance, prev_grads, perform_step):
+    def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
         if prev_grads is None:
             if perform_step:
-                return single_step(t5_tokens, clip_tokens, x, guidance), None
+                return single_step(x, t5_feat, clip_feat, guidance), None
             else:
-                return compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance)
+                return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
         else:
             if perform_step:
                 return (
                     grad_accumulate_and_step(
-                        t5_tokens, clip_tokens, x, guidance, prev_grads
+                        x, t5_feat, clip_feat, guidance, prev_grads
                     ),
                     None,
                 )
             else:
                 return compute_loss_and_accumulate_grads(
-                    t5_tokens, clip_tokens, x, guidance, prev_grads
+                    x, t5_feat, clip_feat, guidance, prev_grads
                 )
 
-    print("Encoding training images to latent space")
-    x = extract_latent_vectors(flux, args.image_folder)
-    t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
+    print("Create the training dataset.", flush=True)
+    dataset = FinetuningDataset(flux, args)
+    dataset.encode_images()
+    dataset.encode_prompts()
+    dataset.generate_prior_preservation()
 
     guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)
 
     # An initial generation to compare
@@ -260,8 +342,7 @@ if __name__ == "__main__":
     grads = None
     losses = []
     tic = time.time()
-    batches = iterate_batches(t5_tokens, clip_tokens, x, args.batch_size)
-    for i, batch in zip(range(args.iterations), batches):
+    for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
         loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
         mx.eval(loss, grads, state)
         losses.append(loss.item())
@@ -272,7 +353,8 @@ if __name__ == "__main__":
         print(
             f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
             f"It/s: {10 / (toc - tic):.3f} "
-            f"Peak mem: {peak_mem:.3f} GB"
+            f"Peak mem: {peak_mem:.3f} GB",
+            flush=True,
         )
 
     if (i + 1) % args.progress_every == 0:
diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py
index b4f63d99..cb116dbb 100644
--- a/flux/flux/__init__.py
+++ b/flux/flux/__init__.py
@@ -126,8 +126,8 @@ class FluxPipeline:
         seed=None,
     ):
         # Set the PRNG state
-        seed = int(time.time()) if seed is None else seed
-        mx.random.seed(seed)
+        if seed is not None:
+            mx.random.seed(seed)
 
         # Create the latent variables
         x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
@@ -154,15 +154,15 @@ class FluxPipeline:
 
     def training_loss(
         self,
-        t5_tokens: mx.array,
-        clip_tokens: mx.array,
         x_0: mx.array,
+        t5_features: mx.array,
+        clip_features: mx.array,
         guidance: mx.array,
     ):
         # Get the text conditioning
-        txt = self.t5(t5_tokens)
-        txt_ids = mx.zeros(t5_tokens.shape + (3,), dtype=mx.int32)
-        vec = self.clip(clip_tokens).pooled_output
+        txt = t5_features
+        txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
+        vec = clip_features
 
         # Prepare the latent input
         x_0, x_ids = self._prepare_latent_images(x_0)
diff --git a/flux/flux/tokenizers.py b/flux/flux/tokenizers.py
index be295281..67820d48 100644
--- a/flux/flux/tokenizers.py
+++ b/flux/flux/tokenizers.py
@@ -118,8 +118,20 @@ class CLIPTokenizer:
 
 
 class T5Tokenizer:
-    def __init__(self, model_file):
+    def __init__(self, model_file, max_length=512):
         self._tokenizer = SentencePieceProcessor(model_file)
+        self.max_length = max_length
+
+    @property
+    def pad(self):
+        try:
+            return self._tokenizer.id_to_piece(self.pad_token)
+        except IndexError:
+            return None
+
+    @property
+    def pad_token(self):
+        return self._tokenizer.pad_id()
 
     @property
     def bos(self):
@@ -143,9 +155,9 @@ class T5Tokenizer:
     def eos_token(self):
         return self._tokenizer.eos_id()
 
-    def tokenize(self, text, prepend_bos=True, append_eos=True):
+    def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
         if isinstance(text, list):
-            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+            return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
 
         tokens = self._tokenizer.encode(text)
 
@@ -153,6 +165,8 @@ class T5Tokenizer:
             tokens = [self.bos_token] + tokens
         if append_eos and self.eos_token >= 0:
             tokens.append(self.eos_token)
+        if pad and len(tokens) < self.max_length and self.pad_token >= 0:
+            tokens += [self.pad_token] * (self.max_length - len(tokens))
 
         return tokens
 
diff --git a/flux/flux/utils.py b/flux/flux/utils.py
index 7c8e9214..e506cda1 100644
--- a/flux/flux/utils.py
+++ b/flux/flux/utils.py
@@ -204,4 +204,4 @@ def load_clip_tokenizer(name: str):
 
 def load_t5_tokenizer(name: str):
     model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
-    return T5Tokenizer(model_file)
+    return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
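
Notes:

The new FinetuningDataset expects the positional "dataset" argument to point
at a folder containing an index.json file. The schema is not shown in this
patch, but judging from the keys the loader reads ("data" with per-sample
"image" and "text" entries, plus an optional "prior_preservation" section
with "prompts", "n_images" and an optional "num_steps"), a minimal index.json
consistent with this code could look as follows; file names and prompts are
placeholders:

    {
      "data": [
        {"image": "images/01.png", "text": "a photo of sks dog"},
        {"image": "images/02.png", "text": "a photo of sks dog on a beach"}
      ],
      "prior_preservation": {
        "prompts": ["a photo of a dog", "a dog playing in a park"],
        "n_images": 16,
        "num_steps": 35
      }
    }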
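
The random_state context manager added at the top of dreambooth.py saves the
global MLX PRNG state on entry, optionally reseeds, and restores the saved
state on exit. A minimal sketch of the intended behavior, assuming
random_state is in scope (the shapes are illustrative):

    import mlx.core as mx

    with random_state(42):
        a = mx.random.normal((2,))  # drawn under seed 42, reproducible
    b = mx.random.normal((2,))  # drawn from the restored outer state

This is what lets generate_latents render progress images reproducibly when
given a seed, without disturbing the random stream used for batch sampling
and noise during training.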
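
Since the positional prompt argument is replaced by the dataset folder and
--progress-prompt is now required, a typical invocation of the updated script
could look like the following, with the prompt and path as placeholders:

    python dreambooth.py --progress-prompt "a photo of sks dog" path/to/dataset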