mlx-examples/flux/dreambooth.py

518 lines
16 KiB
Python
Raw Normal View History

2024-10-12 12:17:41 +08:00
# Copyright © 2024 Apple Inc.
import argparse
import time
from functools import partial
from pathlib import Path
2024-10-14 22:59:56 +08:00
import os
2024-10-12 12:17:41 +08:00
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
2024-10-14 22:59:56 +08:00
from huggingface_hub import HfApi, interpreter_login
from huggingface_hub.utils import HfFolder
2024-11-09 09:15:19 +08:00
from flux import FluxPipeline, Trainer, load_dataset, save_config
2024-10-12 12:17:41 +08:00
class FinetuningDataset:
    """Finetuning dataset that pre-encodes images and prompts with a Flux pipeline.

    The dataset directory must contain an ``index.json`` file. The methods
    below read a ``"data"`` list from it whose entries have an ``"image"``
    path (relative to the dataset directory) and a ``"text"`` prompt.

    Args:
        flux: The pipeline used for encoding. The methods use ``flux.ae``,
            ``flux.t5``, ``flux.clip``, ``flux.tokenize`` and ``flux.dtype``.
        args: Namespace providing ``dataset``, ``resolution`` and
            ``num_augmentations``.

    Raises:
        ValueError: If ``index.json`` does not exist in the dataset directory.
    """

    def __init__(self, flux, args):
        # Imported here because the module-level imports do not provide `json`.
        import json

        self.args = args
        self.flux = flux
        self.dataset_base = Path(args.dataset)

        dataset_index = self.dataset_base / "index.json"
        if not dataset_index.exists():
            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
        with open(dataset_index, "r", encoding="utf-8") as f:
            self.index = json.load(f)

        # Filled in by encode_images() / encode_prompts()
        self.latents = []
        self.t5_features = []
        self.clip_features = []

    @staticmethod
    def _progress(iterable):
        """Wrap ``iterable`` in a tqdm progress bar when tqdm is available.

        tqdm is not imported at module level, so it is resolved lazily here;
        without it the iterable is returned unchanged.
        """
        try:
            from tqdm import tqdm
        except ImportError:
            return iterable
        return tqdm(iterable)

    def _random_crop_resize(self, img):
        """Randomly crop and pan ``img``, then resize it to ``args.resolution``.

        Returns the result as an ``mx.array`` built from the PIL image.
        """
        resolution = self.args.resolution
        width, height = img.size

        # Four values are drawn (the last one is unused) so the amount of
        # randomness consumed per call stays the same.
        a, b, c, _ = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()

        # Random crop the input image between 0.8 to 1.0 of its original
        # dimensions, but never smaller than the target resolution.
        crop_size = (
            max((0.8 + 0.2 * a) * width, resolution[0]),
            max((0.8 + 0.2 * a) * height, resolution[1]),
        )
        pan = (width - crop_size[0], height - crop_size[1])
        img = img.crop(
            (
                pan[0] * b,
                pan[1] * c,
                crop_size[0] + pan[0] * b,
                crop_size[1] + pan[1] * c,
            )
        )

        # Fit the largest rectangle with the ratio of resolution in the image
        # rectangle and crop it from the center.
        width, height = crop_size
        ratio = resolution[0] / resolution[1]
        r1 = (height * ratio, height)
        r2 = (width, width / ratio)
        r = r1 if r1[0] <= width else r2
        img = img.crop(
            (
                (width - r[0]) / 2,
                (height - r[1]) / 2,
                (width + r[0]) / 2,
                (height + r[1]) / 2,
            )
        )

        # Finally resize the image to resolution
        img = img.resize(resolution, Image.LANCZOS)

        return mx.array(np.array(img))

    def encode_images(self):
        """Encode the images in the latent space to prepare for training.

        Every image yields ``args.num_augmentations`` consecutive entries in
        ``self.latents``; ``iterate`` relies on that ordering.
        """
        self.flux.ae.eval()
        for sample in self._progress(self.index["data"]):
            input_img = Image.open(self.dataset_base / sample["image"])
            for _ in range(self.args.num_augmentations):
                img = self._random_crop_resize(input_img)
                # Keep the first three channels and rescale 0..255 to -1..1
                img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
                x_0 = self.flux.ae.encode(img[None])
                x_0 = x_0.astype(self.flux.dtype)
                mx.eval(x_0)
                self.latents.append(x_0)

    def encode_prompts(self):
        """Pre-encode the prompts so that we don't recompute them during
        training (doesn't allow finetuning the text encoders)."""
        for sample in self._progress(self.index["data"]):
            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
            t5_feat = self.flux.t5(t5_tok)
            clip_feat = self.flux.clip(clip_tok).pooled_output
            mx.eval(t5_feat, clip_feat)
            self.t5_features.append(t5_feat)
            self.clip_features.append(clip_feat)

    def iterate(self, batch_size):
        """Yield ``(latents, t5_features, clip_features)`` batches forever.

        The latents are reshuffled on every pass. The last batch of a pass
        may hold fewer than ``batch_size`` items.
        """
        xs = mx.concatenate(self.latents)
        t5 = mx.concatenate(self.t5_features)
        clip = mx.concatenate(self.clip_features)
        mx.eval(xs, t5, clip)
        n_aug = self.args.num_augmentations
        n_latents = len(self.latents)
        while True:
            x_indices = mx.random.permutation(n_latents)
            # Latents are stored n_aug per sample, so integer division maps a
            # latent index back to the index of its prompt features.
            c_indices = x_indices // n_aug
            for i in range(0, n_latents, batch_size):
                x_i = x_indices[i : i + batch_size]
                c_i = c_indices[i : i + batch_size]
                yield xs[x_i], t5[c_i], clip[c_i]
def generate_progress_images(iteration, flux, args):
    """Render sample images for ``args.progress_prompt`` and save them as a grid.

    Four images are generated with ``args.progress_steps`` steps, arranged in
    two rows with padding around each image and around the whole grid, and
    written to ``<output_dir>/<iteration, zero-padded to 7 digits>_progress.png``.
    """
    target_dir = Path(args.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    grid_path = target_dir / f"{iteration:07d}_progress.png"
    print(f"Generating {str(grid_path)}", flush=True)

    # Sample the images and give each one a 4 pixel border
    rows, count = 2, 4
    frames = flux.generate_images(args.progress_prompt, count, args.progress_steps)
    frames = mx.pad(frames, [(0, 0), (4, 4), (4, 4), (0, 0)])

    # Lay the batch out as a rows x cols grid and pad the outer edge
    batch, height, width, channels = frames.shape
    cols = batch // rows
    grid = frames.reshape(rows, cols, height, width, channels)
    grid = grid.transpose(0, 2, 1, 3, 4)
    grid = grid.reshape(rows * height, cols * width, channels)
    grid = mx.pad(grid, [(4, 4), (4, 4), (0, 0)])
    pixels = (grid * 255).astype(mx.uint8)

    # Write the grid to disk
    Image.fromarray(np.array(pixels)).save(grid_path)
2024-11-09 09:15:19 +08:00
def save_adapters(adapter_name, flux, args):
    """Save the trainable parameters of ``flux.flow`` as a safetensors file.

    The file is written to ``<args.output_dir>/<adapter_name>`` and carries
    the LoRA rank and block count as string metadata.
    """
    target_dir = Path(args.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    destination = target_dir / adapter_name
    print(f"Saving {str(destination)}")

    weights = dict(tree_flatten(flux.flow.trainable_parameters()))
    lora_info = {
        "lora_rank": str(args.lora_rank),
        "lora_blocks": str(args.lora_blocks),
    }
    mx.save_safetensors(str(destination), weights, metadata=lora_info)
2024-10-14 22:59:56 +08:00
def push_to_hub(args):
    """Upload the contents of ``args.output_dir`` to the Hugging Face Hub.

    Logs in (interactively when ``args.hf_token`` is None, otherwise by
    storing the given token), writes a generated ``README.md`` into the
    output directory, creates the model repository if needed and uploads the
    folder, skipping ``*.yaml`` and ``*.pt`` files.

    Args:
        args: Parsed command line arguments. Reads ``hf_token``,
            ``hf_repo_id``, ``hf_private`` and ``output_dir`` (plus whatever
            ``generate_readme`` reads).
    """
    if args.hf_token is None:
        interpreter_login(new_session=False, write_permission=True)
    else:
        HfFolder.save_token(args.hf_token)

    # With token=None the client falls back to the token stored by the login
    api = HfApi(token=args.hf_token)

    repo_id = args.hf_repo_id
    if not repo_id:
        # Default to "<logged in user>/<output dir name>". Only the final
        # path component is used because a repo name cannot contain path
        # separators; resolve() also copes with "." or a trailing slash.
        username = api.whoami()["name"]
        repo_id = f"{username}/{Path(args.output_dir).resolve().name}"

    readme_content = generate_readme(args, repo_id)
    readme_path = os.path.join(args.output_dir, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(readme_content)

    api.create_repo(repo_id, private=args.hf_private, exist_ok=True)
    api.upload_folder(
        repo_id=repo_id,
        folder_path=args.output_dir,
        ignore_patterns=["*.yaml", "*.pt"],
        repo_type="model",
    )
def generate_readme(args, repo_id):
    """Build the model card (README.md contents) for the trained adapters.

    The card consists of YAML front matter (tags, an optional widget entry
    pointing at the latest ``*_progress.png`` found in ``args.output_dir``,
    base model and license) followed by usage snippets for MLX and diffusers.

    Args:
        args: Parsed command line arguments. Reads ``model``, ``output_dir``,
            ``progress_prompt``, ``progress_steps``, ``lora_rank`` and
            ``lora_blocks``.
        repo_id: The Hub repository id referenced by the usage snippets.

    Returns:
        The README contents as a string.
    """
    import re

    import yaml

    base_model = f"flux-{args.model}"
    tags = [
        "text-to-image",
        "flux",
        "lora",
        "diffusers",
        "template:sd-lora",
        "mlx",
        "mlx-trainer",
    ]

    # Look for progress images directly in the output directory
    progress_pattern = re.compile(r"(\d+)_progress\.png$")
    progress_images = []
    for filename in os.listdir(args.output_dir):
        match = progress_pattern.search(filename)
        if match:
            progress_images.append((int(match.group(1)), filename))

    # Showcase the image from the latest iteration, if there is one
    widgets = []
    if progress_images:
        _, latest_image = max(progress_images, key=lambda item: item[0])
        widgets.append(
            {
                "text": args.progress_prompt,
                "output": {"url": latest_image},
            }
        )

    # Assemble the front matter from parts so that no empty lines are left
    # behind when there is no widget section.
    front_matter = ["tags:", yaml.dump(tags, indent=4).strip()]
    if widgets:
        front_matter += ["widget:", yaml.dump(widgets, indent=4).strip()]
    front_matter += [f"base_model: {base_model}", "license: other"]
    front_matter = "\n".join(front_matter)

    # repr() gives a correctly quoted and escaped Python literal, so prompts
    # that contain quotes do not break the generated snippets.
    prompt_literal = repr(args.progress_prompt)

    return f"""---
{front_matter}
---

# {os.path.basename(args.output_dir)}

Model trained with the MLX Flux Dreambooth script

<Gallery />

## Use it with [MLX](https://github.com/ml-explore/mlx-examples)

```py
from flux import FluxPipeline
import mlx.core as mx

flux = FluxPipeline("flux-{args.model}")
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
flux.flow.load_weights("{repo_id}")
image = flux.generate_images({prompt_literal}, n_images=1, num_steps={args.progress_steps})
image.save("my_image.png")
```

## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/FLUX.1-{args.model}', torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}')
image = pipeline({prompt_literal}).images[0]
image.save("my_image.png")
```

For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
"""
2024-10-12 12:17:41 +08:00
def _parse_resolution(value):
    """Parse a ``"<int>x<int>"`` string such as ``512x512`` into a 2-tuple.

    Raises ``argparse.ArgumentTypeError`` for anything that is not exactly two
    positive integers, so argparse reports a readable usage error instead of
    letting a malformed tuple reach the training code.
    """
    try:
        resolution = tuple(int(part) for part in value.lower().split("x"))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid resolution '{value}', expected e.g. 512x512"
        ) from None
    if len(resolution) != 2 or min(resolution) <= 0:
        raise argparse.ArgumentTypeError(
            f"invalid resolution '{value}', expected two positive integers "
            "like 512x512"
        )
    return resolution


def setup_arg_parser():
    """Set up and return the argument parser.

    Returns:
        An ``argparse.ArgumentParser`` with the training options, the
        Hugging Face Hub options and the positional ``dataset`` argument.
    """
    parser = argparse.ArgumentParser(
        description="Finetune Flux to generate images with a specific subject"
    )

    parser.add_argument(
        "--model",
        default="dev",
        choices=[
            "dev",
            "schnell",
        ],
        help="Which flux model to train",
    )
    parser.add_argument(
        "--guidance", type=float, default=4.0, help="The guidance factor to use."
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=600,
        help="How many iterations to train for",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="The batch size to use when training the stable diffusion model",
    )
    parser.add_argument(
        "--resolution",
        type=_parse_resolution,
        default=(512, 512),
        help="The resolution of the training images",
    )
    parser.add_argument(
        "--num-augmentations",
        type=int,
        default=5,
        help="Augment the images by random cropping and panning",
    )
    parser.add_argument(
        "--progress-prompt",
        required=True,
        help="Use this prompt when generating images for evaluation",
    )
    parser.add_argument(
        "--progress-steps",
        type=int,
        default=50,
        help="Use this many steps when generating images for evaluation",
    )
    parser.add_argument(
        "--progress-every",
        type=int,
        default=50,
        help="Generate images every PROGRESS_EVERY steps",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=50,
        help="Save the model every CHECKPOINT_EVERY steps",
    )
    parser.add_argument(
        "--lora-blocks",
        type=int,
        default=-1,
        help="Train the last LORA_BLOCKS transformer blocks",
    )
    parser.add_argument(
        "--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
    )
    parser.add_argument(
        "--warmup-steps", type=int, default=100, help="Learning rate warmup"
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-4, help="Learning rate for training"
    )
    parser.add_argument(
        "--grad-accumulate",
        type=int,
        default=4,
        help="Accumulate gradients for that many iterations before applying them",
    )
    parser.add_argument(
        "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
    )

    # Hugging Face Hub options. The dashed spelling matches the other flags;
    # the underscore spelling is kept as an alias for existing command lines.
    # argparse derives the same dest (e.g. args.push_to_hub) from either one.
    parser.add_argument(
        "--push-to-hub",
        "--push_to_hub",
        action="store_true",
        help="Push the model to Hugging Face Hub after training",
    )
    parser.add_argument(
        "--hf-token",
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face token for pushing to Hub",
    )
    parser.add_argument(
        "--hf-repo-id",
        "--hf_repo_id",
        type=str,
        default=None,
        help="Hugging Face repository ID for pushing to Hub",
    )
    parser.add_argument(
        "--hf-private",
        "--hf_private",
        action="store_true",
        help="Make the Hugging Face repository private",
    )

    parser.add_argument("dataset")

    return parser
2024-10-12 12:17:41 +08:00
if __name__ == "__main__":
    # Script entry point: parse the arguments, set up LoRA finetuning of the
    # Flux flow model, run the training loop with periodic progress images
    # and checkpoints, then save the final adapters and optionally push the
    # output directory to the Hugging Face Hub.
    parser = setup_arg_parser()
    args = parser.parse_args()

    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Never persist the Hugging Face token: this file is written into the
    # output directory, which push_to_hub uploads in its entirety.
    config = {k: v for k, v in vars(args).items() if k != "hf_token"}
    save_config(config, output_path / "adapter_config.json")

    # Load the model and set it up for LoRA training. We use the same random
    # state when creating the LoRA layers so all workers will have the same
    # initial weights.
    mx.random.seed(0x0F0F0F0F)
    flux = FluxPipeline("flux-" + args.model)
    flux.flow.freeze()
    flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)

    # Reset the seed to a different seed per worker if we are in distributed
    # mode so that each worker is working on different data, diffusion step and
    # random noise.
    mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())

    # Report how many parameters we are training
    trainable_params = tree_reduce(
        lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
    )
    print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)

    # Set up the optimizer and training steps. The steps are a bit verbose to
    # support gradient accumulation together with compilation.
    warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
    cosine = optim.cosine_decay(
        args.learning_rate, args.iterations // args.grad_accumulate
    )
    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
    optimizer = optim.Adam(learning_rate=lr_schedule)
    state = [flux.flow.state, optimizer.state, mx.random.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def single_step(x, t5_feat, clip_feat, guidance):
        # No accumulation: compute the gradients and update right away
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )
        grads = average_gradients(grads)
        optimizer.update(flux.flow, grads)

        return loss

    @partial(mx.compile, inputs=state, outputs=state)
    def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
        # First micro-step of an accumulation window: nothing to add to yet
        return nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )

    @partial(mx.compile, inputs=state, outputs=state)
    def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
        # Intermediate micro-step: add the new gradients to the running sum
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )
        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
        return loss, grads

    @partial(mx.compile, inputs=state, outputs=state)
    def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
        # Last micro-step: turn the accumulated sum into a mean and update
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )
        grads = tree_map(
            lambda a, b: (a + b) / args.grad_accumulate,
            prev_grads,
            grads,
        )
        grads = average_gradients(grads)
        optimizer.update(flux.flow, grads)

        return loss

    # We simply route to the appropriate step based on whether we have
    # gradients from a previous step and whether we should be performing an
    # update or simply computing and accumulating gradients in this step.
    def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
        if prev_grads is None:
            if perform_step:
                return single_step(x, t5_feat, clip_feat, guidance), None
            else:
                return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
        else:
            if perform_step:
                return (
                    grad_accumulate_and_step(
                        x, t5_feat, clip_feat, guidance, prev_grads
                    ),
                    None,
                )
            else:
                return compute_loss_and_accumulate_grads(
                    x, t5_feat, clip_feat, guidance, prev_grads
                )

    dataset = load_dataset(args.dataset)
    trainer = Trainer(flux, dataset, args)
    trainer.encode_dataset()

    guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)

    # An initial generation to compare
    generate_progress_images(0, flux, args)

    grads = None
    losses = []
    tic = time.time()
    for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
        loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
        mx.eval(loss, grads, state)
        losses.append(loss.item())

        if (i + 1) % 10 == 0:
            toc = time.time()
            peak_mem = mx.metal.get_peak_memory() / 1024**3
            print(
                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
                f"It/s: {10 / (toc - tic):.3f} "
                f"Peak mem: {peak_mem:.3f} GB",
                flush=True,
            )

        if (i + 1) % args.progress_every == 0:
            generate_progress_images(i + 1, flux, args)

        if (i + 1) % args.checkpoint_every == 0:
            save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)

        # Reset the reporting window only after the progress images and
        # checkpoint, so their duration is not counted in the next It/s.
        if (i + 1) % 10 == 0:
            losses = []
            tic = time.time()

    # Save the final adapters before uploading so that they are part of the
    # folder pushed to the Hub.
    save_adapters("final_adapters.safetensors", flux, args)

    if args.push_to_hub:
        push_to_hub(args)

    print("Training successful.")