Cleanup the dreambooth

Angelos Katharopoulos
2024-10-09 23:26:39 -07:00
parent 446d8b6439
commit f538394eec
2 changed files with 66 additions and 70 deletions

View File

@@ -10,7 +10,7 @@ import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_map, tree_reduce, tree_unflatten
from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
from PIL import Image
from tqdm import tqdm
@@ -77,47 +77,6 @@ class FinetuningDataset:
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def generate_prior_preservation(self):
"""Generate some images and mix them with the training images to avoid
overfitting to the dataset."""
prior_preservation = self.index.get("prior_preservation", None)
if not prior_preservation:
return
# Select a random set of prompts from the available ones
prior_prompts = mx.random.randint(
low=0,
high=len(prior_preservation["prompts"]),
shape=(prior_preservation["n_images"],),
).tolist()
# For each prompt
for prompt_idx in tqdm(prior_prompts):
# Create the generator
latents = self.flux.generate_latents(
prior_preservation["prompts"][prompt_idx],
num_steps=prior_preservation.get(
"num_steps", 2 if "schnell" in self.flux.name else 35
),
)
# Extract the t5 and clip features
conditioning = next(latents)
mx.eval(conditioning)
t5_feat = conditioning[2]
clip_feat = conditioning[4]
del conditioning
# Do the denoising
for x_t in latents:
mx.eval(x_t)
# Append everything in the data lists
self.latents.append(x_t)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
@@ -129,6 +88,7 @@ class FinetuningDataset:
def linear_to_lora_layers(flux, args):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
rank = args.lora_rank
all_blocks = flux.flow.double_blocks + flux.flow.single_blocks
all_blocks.reverse()
@@ -141,32 +101,33 @@ def linear_to_lora_layers(flux, args):
block.update_modules(tree_unflatten(loras))
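For context, the hunk above only shows the tail of linear_to_lora_layers. The swap it performs typically looks like the following minimal sketch; the LoRALinear class below is illustrative (the flux example ships its own implementation), while named_modules, update_modules, and tree_unflatten are the MLX APIs the surrounding code already uses.

# Sketch of swapping nn.Linear layers for low-rank adapters in one block.
import math

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class LoRALinear(nn.Module):
    def __init__(self, linear, rank=32, scale=1.0):
        super().__init__()
        out_dims, in_dims = linear.weight.shape
        self.linear = linear  # the frozen base layer
        self.scale = scale
        self.lora_a = mx.random.normal((in_dims, rank)) / math.sqrt(in_dims)
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        return self.linear(x) + self.scale * (x @ self.lora_a) @ self.lora_b


def swap_linears(block, rank):
    loras = [
        (name, LoRALinear(module, rank))
        for name, module in block.named_modules()
        if isinstance(module, nn.Linear)
    ]
    block.update_modules(tree_unflatten(loras))

Because flux.flow is frozen before the swap, the wrapped base layers stay frozen and only the new lora_a/lora_b parameters end up trainable.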
def decode_latents(flux, x):
decoded = []
for i in tqdm(range(len(x))):
decoded.append(flux.decode(x[i : i + 1]))
mx.eval(decoded[-1])
return mx.concatenate(decoded, axis=0)
def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
with random_state(seed):
latents = flux.generate_latents(
prompt,
n_images=n_images,
num_steps=steps,
)
for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
mx.eval(x_t)
return x_t
def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
with random_state(seed):
latents = flux.generate_latents(
prompt,
n_images=n_images,
num_steps=steps,
)
for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
mx.eval(x_t)
return x_t
def decode_latents(flux, x):
decoded = []
for i in tqdm(range(len(x))):
decoded.append(flux.decode(x[i : i + 1]))
mx.eval(decoded[-1])
return mx.concatenate(decoded, axis=0)
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"out_{iteration:03d}.png"
out_file = out_dir / f"{iteration:07d}_progress.png"
print(f"Generating {str(out_file)}", flush=True)
# Generate the latent vectors using diffusion
n_images = 4
latents = generate_latents(
@@ -177,6 +138,9 @@ def generate_progress_images(iteration, flux, args):
seed=42 + mx.distributed.init().rank(),
)
# Reload the text encoders to reduce the memory use during training
flux.reload_text_encoders()
# Arrange them on a grid
n_rows = 2
x = decode_latents(flux, latents)
@@ -192,6 +156,22 @@ def generate_progress_images(iteration, flux, args):
im.save(out_file)
def save_adapters(iteration, flux, args):
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
print(f"Saving {str(out_file)}")
mx.save_safetensors(
str(out_file),
dict(tree_flatten(flux.flow.trainable_parameters())),
metadata={
"lora_rank": args.lora_rank,
"lora_blocks": args.lora_blocks,
},
)
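The adapters saved here can later be restored into a pipeline prepared the same way. A minimal sketch, assuming the same args used for training; the checkpoint file name is illustrative:

# Rebuild the LoRA layers, then load only the saved adapter weights.
flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze()
with random_state(0x0F0F0F0F):
    linear_to_lora_layers(flux, args)
flux.flow.load_weights("0001000_adapters.safetensors", strict=False)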
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
@@ -206,6 +186,9 @@ if __name__ == "__main__":
],
help="Which flux model to train",
)
parser.add_argument(
"--guidance", type=float, default=4.0, help="The guidance factor to use."
)
parser.add_argument(
"--iterations",
type=int,
@@ -248,7 +231,10 @@ if __name__ == "__main__":
help="Save the model every CHECKPOINT_EVERY steps",
)
parser.add_argument(
"--lora-blocks", type=int, default=-1, help="Train the last LORA_BLOCKS blocks"
"--lora-blocks",
type=int,
default=-1,
help="Train the last LORA_BLOCKS transformer blocks",
)
parser.add_argument(
"--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
@@ -277,17 +263,22 @@ if __name__ == "__main__":
# setting.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same
# initial weights.
flux = FluxPipeline("flux-" + args.model)
flux.ensure_models_are_loaded()
flux.flow.freeze()
with random_state(0x0F0F0F0F):
linear_to_lora_layers(flux, args)
# Report how many parameters we are training
trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
)
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
# Set up the optimizer and training steps. The steps are a bit verbose to
# support gradient accumulation together with compilation.
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
cosine = optim.cosine_decay(
args.learning_rate, args.iterations // args.grad_accumulate
@@ -331,6 +322,9 @@ if __name__ == "__main__":
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
if prev_grads is None:
if perform_step:
@@ -354,8 +348,7 @@ if __name__ == "__main__":
dataset = FinetuningDataset(flux, args)
dataset.encode_images()
dataset.encode_prompts()
dataset.generate_prior_preservation()
guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
generate_progress_images(0, flux, args)
@@ -382,8 +375,7 @@ if __name__ == "__main__":
generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0:
pass
# save_checkpoints(i + 1, sd, args)
save_adapters(i + 1, flux, args)
if (i + 1) % 10 == 0:
losses = []
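The step-routing comment earlier in this file (accumulate gradients vs. perform an update) reduces to a pattern like the following minimal sketch; the function below is illustrative and not the exact compiled helpers used above.

# Sum gradients across micro-batches; apply the optimizer update only on
# the step that completes an accumulation window of size accumulate_n.
from mlx.utils import tree_map

def accumulate_or_update(model, optimizer, grads, prev_grads, perform_step, accumulate_n):
    if prev_grads is not None:
        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
    if perform_step:
        grads = tree_map(lambda g: g / accumulate_n, grads)
        optimizer.update(model, grads)
        return None
    return grads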

View File

@@ -36,6 +36,10 @@ class FluxPipeline:
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text)
clip_tokens = self.clip_tokenizer.encode(text)