Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-31 20:04:38 +08:00)
Cleanup the dreambooth
@@ -10,7 +10,7 @@ import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 from mlx.nn.utils import average_gradients
-from mlx.utils import tree_map, tree_reduce, tree_unflatten
+from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
 from PIL import Image
 from tqdm import tqdm

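The added tree_flatten import is what the new save_adapters helper below uses to turn the nested tree of trainable parameters into the flat name-to-array dict that mx.save_safetensors expects. A minimal sketch of that round trip with a toy module (a stand-in, not the example's model):

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(4, 4)  # stand-in for flux.flow
    flat = dict(tree_flatten(model.trainable_parameters()))  # {"weight": ..., "bias": ...}
    mx.save_safetensors("adapters.safetensors", flat, metadata={"lora_rank": "32"})

    # Loading reverses the flattening before updating the module in place.
    model.update(tree_unflatten(list(mx.load("adapters.safetensors").items())))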
@@ -77,47 +77,6 @@ class FinetuningDataset:
         self.t5_features.append(t5_feat)
         self.clip_features.append(clip_feat)

-    def generate_prior_preservation(self):
-        """Generate some images and mix them with the training images to avoid
-        overfitting to the dataset."""
-
-        prior_preservation = self.index.get("prior_preservation", None)
-        if not prior_preservation:
-            return
-
-        # Select a random set of prompts from the available ones
-        prior_prompts = mx.random.randint(
-            low=0,
-            high=len(prior_preservation["prompts"]),
-            shape=(prior_preservation["n_images"],),
-        ).tolist()
-
-        # For each prompt
-        for prompt_idx in tqdm(prior_prompts):
-            # Create the generator
-            latents = self.flux.generate_latents(
-                prior_preservation["prompts"][prompt_idx],
-                num_steps=prior_preservation.get(
-                    "num_steps", 2 if "schnell" in self.flux.name else 35
-                ),
-            )
-
-            # Extract the t5 and clip features
-            conditioning = next(latents)
-            mx.eval(conditioning)
-            t5_feat = conditioning[2]
-            clip_feat = conditioning[4]
-            del conditioning
-
-            # Do the denoising
-            for x_t in latents:
-                mx.eval(x_t)
-
-            # Append everything in the data lists
-            self.latents.append(x_t)
-            self.t5_features.append(t5_feat)
-            self.clip_features.append(clip_feat)
-
     def iterate(self, batch_size):
         xs = mx.concatenate(self.latents)
         t5 = mx.concatenate(self.t5_features)
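Only the first lines of iterate fall inside this hunk. A rough sketch of how a batching generator of this shape typically continues (an assumption, not the script's actual implementation: the CLIP features are concatenated the same way and aligned, shuffled batches are yielded):

    def iterate(self, batch_size):
        xs = mx.concatenate(self.latents)
        t5 = mx.concatenate(self.t5_features)
        clip = mx.concatenate(self.clip_features)
        while True:
            # Draw a fresh shuffled order each pass and yield aligned batches.
            order = np.random.permutation(len(xs))
            for i in range(0, len(xs) - batch_size + 1, batch_size):
                idx = mx.array(order[i : i + batch_size])
                yield xs[idx], t5[idx], clip[idx]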
@@ -129,6 +88,7 @@ class FinetuningDataset:


 def linear_to_lora_layers(flux, args):
+    """Swap the linear layers in the transformer blocks with LoRA layers."""
     rank = args.lora_rank
     all_blocks = flux.flow.double_blocks + flux.flow.single_blocks
     all_blocks.reverse()
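The loop that builds the LoRA replacements sits between this hunk and the next, which ends with block.update_modules(tree_unflatten(loras)). As a self-contained illustration of that swap pattern, here is a toy low-rank wrapper (the example's actual LoRA layer class is not shown in this diff and may differ; named_modules() enumerating submodules with dotted paths is an assumption):

    import math
    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_unflatten

    class ToyLoRALinear(nn.Module):
        """Adds a trainable low-rank update on top of an existing linear layer."""

        def __init__(self, linear, rank):
            super().__init__()
            out_dims, in_dims = linear.weight.shape
            self.linear = linear
            scale = 1.0 / math.sqrt(in_dims)
            self.lora_a = mx.random.uniform(-scale, scale, (in_dims, rank))
            self.lora_b = mx.zeros((rank, out_dims))

        def __call__(self, x):
            return self.linear(x) + (x @ self.lora_a) @ self.lora_b

    def toy_linear_to_lora(block, rank):
        # Replace every nn.Linear in the block with a wrapped copy.
        loras = [
            (name, ToyLoRALinear(module, rank))
            for name, module in block.named_modules()
            if isinstance(module, nn.Linear)
        ]
        block.update_modules(tree_unflatten(loras))

Because the script freezes flux.flow before the swap, only the new lora_a and lora_b arrays end up in trainable_parameters().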
@@ -141,32 +101,33 @@ def linear_to_lora_layers(flux, args):
         block.update_modules(tree_unflatten(loras))


 def decode_latents(flux, x):
     decoded = []
     for i in tqdm(range(len(x))):
         decoded.append(flux.decode(x[i : i + 1]))
         mx.eval(decoded[-1])
     return mx.concatenate(decoded, axis=0)


 def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
     with random_state(seed):
         latents = flux.generate_latents(
             prompt,
             n_images=n_images,
             num_steps=steps,
         )
         for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
             mx.eval(x_t)

         return x_t


 def generate_progress_images(iteration, flux, args):
     """Generate images to monitor the progress of the finetuning."""

     def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
         with random_state(seed):
             latents = flux.generate_latents(
                 prompt,
                 n_images=n_images,
                 num_steps=steps,
             )
             for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
                 mx.eval(x_t)

             return x_t

     def decode_latents(flux, x):
         decoded = []
         for i in tqdm(range(len(x))):
             decoded.append(flux.decode(x[i : i + 1]))
             mx.eval(decoded[-1])
         return mx.concatenate(decoded, axis=0)

     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"out_{iteration:03d}.png"
+    out_file = out_dir / f"{iteration:07d}_progress.png"
     print(f"Generating {str(out_file)}", flush=True)

     # Generate the latent vectors using diffusion
     n_images = 4
     latents = generate_latents(
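A short usage sketch of generate_latents and decode_latents as they are used for progress images (the prompt, step count, and the assumption that decoded images come back as an (n_images, H, W, 3) array in [0, 1] are illustrative; this also relies on the script's globals such as flux and args):

    import numpy as np
    from PIL import Image

    latents = generate_latents(flux, n_images=4, prompt="a photo of an sks dog", steps=20, seed=0)
    x = decode_latents(flux, latents)

    # Tile the batch into a 2-row grid and save it, assuming values in [0, 1].
    x = (np.array(x) * 255).astype(np.uint8)
    B, H, W, C = x.shape
    grid = x.reshape(2, B // 2, H, W, C).transpose(0, 2, 1, 3, 4).reshape(2 * H, (B // 2) * W, C)
    Image.fromarray(grid).save("preview.png")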
@@ -177,6 +138,9 @@ def generate_progress_images(iteration, flux, args):
         seed=42 + mx.distributed.init().rank(),
     )

+    # Reload the text encoders to reduce the memory use during training
+    flux.reload_text_encoders()
+
     # Arrange them on a grid
     n_rows = 2
     x = decode_latents(flux, latents)
@@ -192,6 +156,22 @@ def generate_progress_images(iteration, flux, args):
     im.save(out_file)


+def save_adapters(iteration, flux, args):
+    out_dir = Path(args.output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
+    print(f"Saving {str(out_file)}")
+
+    mx.save_safetensors(
+        str(out_file),
+        dict(tree_flatten(flux.flow.trainable_parameters())),
+        metadata={
+            "lora_rank": str(args.lora_rank),
+            "lora_blocks": str(args.lora_blocks),
+        },
+    )
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Finetune Flux to generate images with a specific subject"
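The adapters written by save_adapters can later be merged back into a freshly prepared model. A sketch of the loading direction (the commit only adds the saving side; the file name, model choice, and use of Module.update are assumptions):

    import mlx.core as mx
    from mlx.utils import tree_unflatten

    flux = FluxPipeline("flux-dev")      # assumed model choice
    flux.flow.freeze()
    linear_to_lora_layers(flux, args)    # recreate the same LoRA structure as during training

    adapters = list(mx.load("0001000_adapters.safetensors").items())  # hypothetical file name
    flux.flow.update(tree_unflatten(adapters))
    mx.eval(flux.flow.parameters())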
@@ -206,6 +186,9 @@ if __name__ == "__main__":
         ],
         help="Which flux model to train",
     )
+    parser.add_argument(
+        "--guidance", type=float, default=4.0, help="The guidance factor to use."
+    )
     parser.add_argument(
         "--iterations",
         type=int,
@@ -248,7 +231,10 @@ if __name__ == "__main__":
         help="Save the model every CHECKPOINT_EVERY steps",
     )
     parser.add_argument(
-        "--lora-blocks", type=int, default=-1, help="Train the last LORA_BLOCKS blocks"
+        "--lora-blocks",
+        type=int,
+        default=-1,
+        help="Train the last LORA_BLOCKS transformer blocks",
     )
     parser.add_argument(
         "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
@@ -277,17 +263,22 @@ if __name__ == "__main__":
     # setting.
     mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())

     # Load the model and set it up for LoRA training. We use the same random
     # state when creating the LoRA layers so all workers will have the same
     # initial weights.
     flux = FluxPipeline("flux-" + args.model)
     flux.ensure_models_are_loaded()
     flux.flow.freeze()
     with random_state(0x0F0F0F0F):
         linear_to_lora_layers(flux, args)

     # Report how many parameters we are training
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
     )
     print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)

     # Set up the optimizer and training steps. The steps are a bit verbose to
     # support gradient accumulation together with compilation.
     warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
     cosine = optim.cosine_decay(
         args.learning_rate, args.iterations // args.grad_accumulate
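The warmup and cosine schedules are presumably joined into a single learning-rate schedule just after this hunk; a sketch of the usual MLX pattern (variable names other than warmup and cosine are assumptions):

    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
    optimizer = optim.Adam(learning_rate=lr_schedule)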
@@ -331,6 +322,9 @@ if __name__ == "__main__":

         return loss

+    # We simply route to the appropriate step based on whether we have
+    # gradients from a previous step and whether we should be performing an
+    # update or simply computing and accumulating gradients in this step.
     def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
         if prev_grads is None:
             if perform_step:
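The compiled helpers that this routing dispatches to fall outside the hunk. As a standalone illustration of the accumulate-then-update idea with a toy model (none of these names come from the script):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from mlx.utils import tree_map

    model = nn.Linear(8, 1)
    optimizer = optim.Adam(learning_rate=1e-3)
    grad_accumulate = 4

    def loss_fn(model, x, y):
        return ((model(x) - y) ** 2).mean()

    loss_and_grad = nn.value_and_grad(model, loss_fn)

    grads_sum = None
    for i in range(16):
        x, y = mx.random.normal((2, 8)), mx.random.normal((2, 1))
        loss, grads = loss_and_grad(model, x, y)
        # Keep a running sum of gradients until it is time to update.
        grads_sum = grads if grads_sum is None else tree_map(lambda a, b: a + b, grads_sum, grads)
        if (i + 1) % grad_accumulate == 0:
            optimizer.update(model, tree_map(lambda g: g / grad_accumulate, grads_sum))
            mx.eval(model.parameters(), optimizer.state)
            grads_sum = None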
@@ -354,8 +348,7 @@ if __name__ == "__main__":
     dataset = FinetuningDataset(flux, args)
     dataset.encode_images()
     dataset.encode_prompts()
-    dataset.generate_prior_preservation()
-    guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)
+    guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)

     # An initial generation to compare
     generate_progress_images(0, flux, args)
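How these pieces feed the training loop (elided from the diff), sketched under the assumptions that iterate yields (latents, t5, clip) batches and that step returns the loss plus any accumulated gradients to pass back in:

    batches = dataset.iterate(args.batch_size)
    grads = None
    for i, (x, t5_feat, clip_feat) in zip(range(args.iterations), batches):
        perform_step = (i + 1) % args.grad_accumulate == 0
        loss, grads = step(x, t5_feat, clip_feat, guidance, grads, perform_step)
        mx.eval(loss)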
@@ -382,8 +375,7 @@ if __name__ == "__main__":
             generate_progress_images(i + 1, flux, args)

         if (i + 1) % args.checkpoint_every == 0:
-            pass
-            # save_checkpoints(i + 1, sd, args)
+            save_adapters(i + 1, flux, args)

         if (i + 1) % 10 == 0:
             losses = []
@@ -36,6 +36,10 @@ class FluxPipeline:
             self.t5.parameters(),
         )

+    def reload_text_encoders(self):
+        self.t5 = load_t5(self.name)
+        self.clip = load_clip(self.name)
+
     def tokenize(self, text):
         t5_tokens = self.t5_tokenizer.encode(text)
         clip_tokens = self.clip_tokenizer.encode(text)