Flux lora training

Author: Angelos Katharopoulos, 2024-10-03 11:35:56 -07:00
parent 9eef46e645
commit 7cffcdcaff
6 changed files with 452 additions and 15 deletions

flux/dreambooth.py (new file, 212 lines)

@@ -0,0 +1,212 @@
import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_reduce, tree_unflatten
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
from flux.lora import LoRALinear
def linear_to_lora_layers(flux):
lora_layers = []
for name, mod in flux.flow.named_modules():
if ".img_attn" not in name:
continue
if ".qkv" in name or ".proj" in name:
lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
flux.flow.update_modules(tree_unflatten(lora_layers))
def extract_latent_vectors(flux, image_folder):
flux.ae.eval()
latents = []
for image in tqdm(Path(image_folder).iterdir()):
img = Image.open(image)
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
x_0 = flux.ae.encode(img[None])
x_0 = x_0.astype(flux.dtype)
mx.eval(x_0)
latents.append(x_0)
return mx.concatenate(latents)
def decode_latents(flux, x):
decoded = []
for i in tqdm(range(len(x))):
decoded.append(flux.decode(x[i : i + 1]))
mx.eval(decoded[-1])
return mx.concatenate(decoded, axis=0)
def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
latents = flux.generate_latents(
prompt,
n_images=n_images,
num_steps=steps,
seed=seed,
)
    for x_t in tqdm(latents, total=steps, leave=leave):
mx.eval(x_t)
return x_t
def generate_progress_images(iteration, flux, args):
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"out_{iteration:03d}.png"
print(f"Generating {str(out_file)}")
# Generate the latent vectors using diffusion
n_images = 4
latents = generate_latents(
flux,
n_images,
args.progress_prompt,
args.progress_steps,
seed=42,
)
# Arrange them on a grid
n_rows = 2
x = decode_latents(flux, latents)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(n_rows * H, B // n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(out_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
)
parser.add_argument(
"--model",
default="dev",
choices=[
"dev",
"schnell",
],
help="Which flux model to train",
)
parser.add_argument(
"--iterations",
type=int,
default=400,
help="How many iterations to train for",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size to use when training the stable diffusion model",
)
parser.add_argument(
"--progress-prompt",
help="Use this prompt when generating images for evaluation",
)
parser.add_argument(
"--progress-steps",
type=int,
default=50,
help="Use this many steps when generating images for evaluation",
)
parser.add_argument(
"--progress-every",
type=int,
default=50,
help="Generate images every PROGRESS_EVERY steps",
)
parser.add_argument(
"--checkpoint-every",
type=int,
default=50,
help="Save the model every CHECKPOINT_EVERY steps",
)
parser.add_argument(
"--learning-rate", type=float, default="1e-6", help="Learning rate for training"
)
parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
)
parser.add_argument("prompt")
parser.add_argument("image_folder")
args = parser.parse_args()
args.progress_prompt = args.progress_prompt or args.prompt
flux = FluxPipeline("flux-" + args.model)
flux.ensure_models_are_loaded()
flux.flow.freeze()
linear_to_lora_layers(flux)
trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
)
print(f"Training {trainable_params / 1024**2:.3f}M parameters")
optimizer = optim.Adam(learning_rate=args.learning_rate)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(t5_tokens, clip_tokens, x, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
t5_tokens, clip_tokens, x, guidance
)
optimizer.update(flux.flow, grads)
return loss
print("Encoding training images to latent space")
x = extract_latent_vectors(flux, args.image_folder)
t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)
# An initial generation to compare
generate_progress_images(0, flux, args)
losses = []
tic = time.time()
for i in range(args.iterations):
indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
mx.uint32
)
loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance)
mx.eval(loss, state)
losses.append(loss.item())
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
print(
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB"
)
if (i + 1) % args.progress_every == 0:
generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0:
pass
# save_checkpoints(i + 1, sd, args)
if (i + 1) % 10 == 0:
losses = []
tic = time.time()
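
The checkpointing branch above is still a stub, so the learned adapters only live in memory for the duration of the run. A minimal sketch (not part of this commit) of how the trained adapters could be folded back into the flow transformer for plain inference, reusing the `fuse` method from `flux/flux/lora.py` below and the same `named_modules`/`update_modules` pattern as `linear_to_lora_layers`:

from mlx.utils import tree_unflatten
from flux.lora import LoRALinear

def fuse_lora_layers(flux):
    # Swap every LoRALinear back for a plain nn.Linear whose weight already
    # contains the learned low-rank update (scale * lora_b.T @ lora_a.T).
    fused = [
        (name, module.fuse())
        for name, module in flux.flow.named_modules()
        if isinstance(module, LoRALinear)
    ]
    flux.flow.update_modules(tree_unflatten(fused))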

@@ -36,6 +36,11 @@ class FluxPipeline:
self.t5.parameters(),
)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
@@ -56,16 +61,14 @@ class FluxPipeline:
return x, x_ids
-    def _prepare_conditioning(self, n_images, text):
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
-        t5_tokens = self.t5_tokenizer.encode(text)
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
-        clip_tokens = self.clip_tokenizer.encode(text)
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
@@ -131,8 +134,13 @@ class FluxPipeline:
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
-        txt, txt_ids, vec = self._prepare_conditioning(n_images, text)
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
@@ -142,6 +150,38 @@ class FluxPipeline:
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
-        x = (mx.clip(x + 1, 0, 2) * 127.5).astype(mx.uint8)
return mx.clip(x + 1, 0, 2) * 0.5
-        return x
def training_loss(
self,
t5_tokens: mx.array,
clip_tokens: mx.array,
x_0: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = self.t5(t5_tokens)
txt_ids = mx.zeros(t5_tokens.shape + (3,), dtype=mx.int32)
vec = self.clip(clip_tokens).pooled_output
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process (we use rf/lognorm(0, 1))
t = mx.sigmoid(mx.random.normal(shape=(len(x_0),), dtype=self.dtype))
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred - (eps - x_0)).square().mean()
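
For reference, the regression target `eps - x_0` above is the constant velocity of the rectified-flow interpolation, assuming `self.sampler.add_noise` implements x_t = (1 - t) * x_0 + t * eps (the sampler itself is not shown in this diff, so this is an assumption). A minimal sketch of that noising step:

import mlx.core as mx

def add_noise_sketch(x_0, t, noise):
    # Assumed rectified-flow forward process: linearly interpolate between the
    # clean latents and Gaussian noise. Its time derivative is noise - x_0,
    # which is what training_loss regresses the flow model onto.
    t = t.reshape(t.shape + (1,) * (x_0.ndim - t.ndim))
    return (1 - t) * x_0 + t * noise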

@@ -296,14 +296,9 @@ class Decoder(nn.Module):
class DiagonalGaussian(nn.Module):
-    def __init__(self, sample: bool = True, chunk_dim: int = 1):
-        super().__init__()
-        self.sample = sample
-        self.chunk_dim = chunk_dim
def __call__(self, z: mx.array):
-        mean, logvar = mx.split(z, 2, axis=self.chunk_dim)
-        if self.sample:
mean, logvar = mx.split(z, 2, axis=-1)
if self.training:
std = mx.exp(0.5 * logvar)
            eps = mx.random.normal(shape=mean.shape, dtype=z.dtype)
return mean + std * eps
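
The change above replaces the `sample`/`chunk_dim` flags with the module's training flag: the latent posterior is sampled with the reparameterization trick only in training mode, while in eval mode (as used by `extract_latent_vectors` after `flux.ae.eval()`) the encoder presumably just returns the mean. A compact sketch of the assumed behaviour:

import mlx.core as mx

def diagonal_gaussian_sketch(z, training):
    # z packs (mean, logvar) along the last axis.
    mean, logvar = mx.split(z, 2, axis=-1)
    if training:
        # Reparameterization trick: mean + std * eps.
        std = mx.exp(0.5 * logvar)
        eps = mx.random.normal(shape=mean.shape, dtype=z.dtype)
        return mean + std * eps
    # Assumed eval-mode behaviour (the else branch falls outside this hunk).
    return mean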

flux/flux/lora.py (new file, 74 lines)

@@ -0,0 +1,74 @@
import math
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(
input_dims=input_dims,
output_dims=output_dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_lin.linear = linear
return lora_lin
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
return y + (self.scale * z).astype(x.dtype)
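
A small usage sketch (illustrative, not from the commit) showing how `LoRALinear.from_base` wraps an existing projection and how `fuse` folds the low-rank update back into a single linear layer:

import mlx.core as mx
import mlx.nn as nn
from flux.lora import LoRALinear

base = nn.Linear(3072, 3072)
lora = LoRALinear.from_base(base, r=8)  # the base weights are kept in lora.linear

x = mx.random.normal((2, 10, 3072))
y = lora(x)  # base(x) + scale * (dropout(x) @ lora_a) @ lora_b

# lora_b starts at zero, so fusing right away reproduces the base layer; after
# training it bakes the learned update into one nn.Linear for inference.
fused = lora.fuse()
print(mx.allclose(y, fused(x)))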

flux/txt2image.py (new file, 116 lines)

@@ -0,0 +1,116 @@
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def to_latent_size(image_size):
h, w = image_size
h = ((h + 15) // 16) * 16
w = ((w + 15) // 16) * 16
if (h, w) != image_size:
print(
"Warning: The image dimensions need to be divisible by 16px. "
f"Changing size to {h}x{w}."
)
return (h // 8, w // 8)
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument(
"--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
)
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--seed", type=int)
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
# Load the models
flux = FluxPipeline("flux-" + args.model)
args.steps = args.steps or (50 if args.model == "dev" else 2)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.preload_models:
        flux.ensure_models_are_loaded()
# Make the generator
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=args.n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
seed=args.seed,
)
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
del flux.t5
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the text: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")

@@ -151,7 +151,7 @@ if __name__ == "__main__":
help="Save the model every CHECKPOINT_EVERY steps",
)
parser.add_argument(
"--learning_rate", type=float, default="1e-6", help="Learning rate for training"
"--learning-rate", type=float, default="1e-6", help="Learning rate for training"
)
parser.add_argument(
"--predict-x0",
@@ -201,7 +201,7 @@ if __name__ == "__main__":
sd = StableDiffusion(args.model)
sd.ensure_models_are_loaded()
-    optimizer = optim.Adam(learning_rate=1e-6)
optimizer = optim.Adam(learning_rate=args.learning_rate)
def loss_fn(params, text, x, weights):
sd.unet.update(params["unet"])