Flux lora training

parent 9eef46e645
commit 7cffcdcaff
flux/dreambooth.py (new file, 212 lines)
@@ -0,0 +1,212 @@
import argparse
import time
from functools import partial
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_reduce, tree_unflatten
from PIL import Image
from tqdm import tqdm

from flux import FluxPipeline
from flux.lora import LoRALinear


def linear_to_lora_layers(flux):
    lora_layers = []
    for name, mod in flux.flow.named_modules():
        if ".img_attn" not in name:
            continue
        if ".qkv" in name or ".proj" in name:
            lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
    flux.flow.update_modules(tree_unflatten(lora_layers))
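    # Illustration (not in the original diff): lora_layers is a flat list of
    # (dotted_name, module) pairs, e.g. ("double_blocks.0.img_attn.qkv", lora);
    # tree_unflatten nests it into the mapping that update_modules expects.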


def extract_latent_vectors(flux, image_folder):
    flux.ae.eval()
    latents = []
    for image in tqdm(Path(image_folder).iterdir()):
        img = Image.open(image)
        img = mx.array(np.array(img))
        img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
        x_0 = flux.ae.encode(img[None])
        x_0 = x_0.astype(flux.dtype)
        mx.eval(x_0)
        latents.append(x_0)
    return mx.concatenate(latents)


def decode_latents(flux, x):
    decoded = []
    for i in tqdm(range(len(x))):
        decoded.append(flux.decode(x[i : i + 1]))
        mx.eval(decoded[-1])
    return mx.concatenate(decoded, axis=0)


def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
    latents = flux.generate_latents(
        prompt,
        n_images=n_images,
        num_steps=steps,
        seed=seed,
    )
    for x_t in tqdm(latents, total=steps, leave=leave):
        mx.eval(x_t)

    return x_t


def generate_progress_images(iteration, flux, args):
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"out_{iteration:03d}.png"
    print(f"Generating {str(out_file)}")

    # Generate the latent vectors using diffusion
    n_images = 4
    latents = generate_latents(
        flux,
        n_images,
        args.progress_prompt,
        args.progress_steps,
        seed=42,
    )

    # Arrange them on a grid
    n_rows = 2
    x = decode_latents(flux, latents)
    x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
    B, H, W, C = x.shape
    x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(n_rows * H, B // n_rows * W, C)
    x = (x * 255).astype(mx.uint8)

    # Save them to disk
    im = Image.fromarray(np.array(x))
    im.save(out_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Finetune Flux to generate images with a specific subject"
    )

    parser.add_argument(
        "--model",
        default="dev",
        choices=["dev", "schnell"],
        help="Which flux model to train",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=400,
        help="How many iterations to train for",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size to use when training the flux model",
    )
    parser.add_argument(
        "--progress-prompt",
        help="Use this prompt when generating images for evaluation",
    )
    parser.add_argument(
        "--progress-steps",
        type=int,
        default=50,
        help="Use this many steps when generating images for evaluation",
    )
    parser.add_argument(
        "--progress-every",
        type=int,
        default=50,
        help="Generate images every PROGRESS_EVERY steps",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=50,
        help="Save the model every CHECKPOINT_EVERY steps",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-6, help="Learning rate for training"
    )
    parser.add_argument(
        "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
    )

    parser.add_argument("prompt")
    parser.add_argument("image_folder")

    args = parser.parse_args()

    args.progress_prompt = args.progress_prompt or args.prompt

    flux = FluxPipeline("flux-" + args.model)
    flux.ensure_models_are_loaded()
    flux.flow.freeze()
    linear_to_lora_layers(flux)

    trainable_params = tree_reduce(
        lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
    )
    print(f"Training {trainable_params / 1024**2:.3f}M parameters")

    optimizer = optim.Adam(learning_rate=args.learning_rate)
    state = [flux.flow.state, optimizer.state, mx.random.state]

    # Pass the model, optimizer, and RNG state through mx.compile so the
    # in-place updates made inside `step` survive compilation.
    @partial(mx.compile, inputs=state, outputs=state)
    def step(t5_tokens, clip_tokens, x, guidance):
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            t5_tokens, clip_tokens, x, guidance
        )
        optimizer.update(flux.flow, grads)

        return loss

    print("Encoding training images to latent space")
    x = extract_latent_vectors(flux, args.image_folder)
    t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
    guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)

    # An initial generation to compare against
    generate_progress_images(0, flux, args)

    losses = []
    tic = time.time()
    for i in range(args.iterations):
        # Sample a random batch of training latents (and matching tokens)
        indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
            mx.uint32
        )
        loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance)
        mx.eval(loss, state)
        losses.append(loss.item())

        if (i + 1) % 10 == 0:
            toc = time.time()
            peak_mem = mx.metal.get_peak_memory() / 1024**3
            print(
                f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
                f"It/s: {10 / (toc - tic):.3f} "
                f"Peak mem: {peak_mem:.3f} GB"
            )

        if (i + 1) % args.progress_every == 0:
            generate_progress_images(i + 1, flux, args)

        if (i + 1) % args.checkpoint_every == 0:
            # Checkpoint saving is not implemented in this commit yet
            pass
            # save_checkpoints(i + 1, flux, args)

        if (i + 1) % 10 == 0:
            losses = []
            tic = time.time()
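Checkpoint saving is stubbed out in the loop above. A minimal sketch of what such a helper might look like, saving only the trainable LoRA parameters with safetensors (the function name, signature, and file naming are hypothetical, not part of this commit):

# Hypothetical sketch, not part of this commit.
from mlx.utils import tree_flatten

def save_checkpoints(iteration, flux, args):
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"adapter_{iteration:03d}.safetensors"
    # Only the LoRA layers are trainable, so this keeps the files small.
    params = dict(tree_flatten(flux.flow.trainable_parameters()))
    mx.save_safetensors(str(out_file), params)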
flux/flux/flux.py

@@ -36,6 +36,11 @@ class FluxPipeline:
             self.t5.parameters(),
         )

+    def tokenize(self, text):
+        t5_tokens = self.t5_tokenizer.encode(text)
+        clip_tokens = self.clip_tokenizer.encode(text)
+        return t5_tokens, clip_tokens
+
     def _prepare_latent_images(self, x):
         b, h, w, c = x.shape

@@ -56,16 +61,14 @@ class FluxPipeline:

         return x, x_ids

-    def _prepare_conditioning(self, n_images, text):
+    def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
         # Prepare the text features
-        t5_tokens = self.t5_tokenizer.encode(text)
         txt = self.t5(t5_tokens)
         if len(txt) == 1 and n_images > 1:
             txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
         txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)

         # Prepare the clip text features
-        clip_tokens = self.clip_tokenizer.encode(text)
         vec = self.clip(clip_tokens).pooled_output
         if len(vec) == 1 and n_images > 1:
             vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
@@ -131,8 +134,13 @@ class FluxPipeline:
         x_T, x_ids = self._prepare_latent_images(x_T)

         # Get the conditioning
-        txt, txt_ids, vec = self._prepare_conditioning(n_images, text)
+        t5_tokens, clip_tokens = self.tokenize(text)
+        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)

         # Yield the conditioning for controlled evaluation by the caller
         yield (x_T, x_ids, txt, txt_ids, vec)

         # Yield the latent sequences from the denoising loop
         yield from self._denoising_loop(
             x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
         )
@@ -142,6 +150,38 @@ class FluxPipeline:
         x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
         x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
         x = self.ae.decode(x)
-        x = (mx.clip(x + 1, 0, 2) * 127.5).astype(mx.uint8)
+        return mx.clip(x + 1, 0, 2) * 0.5

-        return x
+    def training_loss(
+        self,
+        t5_tokens: mx.array,
+        clip_tokens: mx.array,
+        x_0: mx.array,
+        guidance: mx.array,
+    ):
+        # Get the text conditioning
+        txt = self.t5(t5_tokens)
+        txt_ids = mx.zeros(t5_tokens.shape + (3,), dtype=mx.int32)
+        vec = self.clip(clip_tokens).pooled_output
+
+        # Prepare the latent input
+        x_0, x_ids = self._prepare_latent_images(x_0)
+
+        # Forward process (we use rf/lognorm(0, 1))
+        t = mx.sigmoid(mx.random.normal(shape=(len(x_0),), dtype=self.dtype))
+        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+        x_t = self.sampler.add_noise(x_0, t, noise=eps)
+        x_t = mx.stop_gradient(x_t)
+
+        # Do the denoising
+        pred = self.flow(
+            img=x_t,
+            img_ids=x_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t,
+            guidance=guidance,
+        )
+
+        return (pred - (eps - x_0)).square().mean()
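For reference, the new training_loss is the rectified-flow objective with logit-normally distributed timesteps. Assuming sampler.add_noise linearly interpolates between the clean latents and the noise (a reading of code not shown in this hunk), the loss corresponds to

    t = \sigma(u), \quad u \sim \mathcal{N}(0, 1), \qquad x_t = (1 - t)\, x_0 + t\, \epsilon

    \mathcal{L} = \mathbb{E}\, \lVert \mathrm{pred}(x_t, t) - (\epsilon - x_0) \rVert^2

so the flow transformer learns the constant velocity \epsilon - x_0 along the straight path from data to noise.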
flux/flux/autoencoder.py

@@ -296,14 +296,9 @@ class Decoder(nn.Module):


 class DiagonalGaussian(nn.Module):
-    def __init__(self, sample: bool = True, chunk_dim: int = 1):
-        super().__init__()
-        self.sample = sample
-        self.chunk_dim = chunk_dim
-
     def __call__(self, z: mx.array):
-        mean, logvar = mx.split(z, 2, axis=self.chunk_dim)
-        if self.sample:
+        mean, logvar = mx.split(z, 2, axis=-1)
+        if self.training:
             std = mx.exp(0.5 * logvar)
             eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
             return mean + std * eps
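The rewritten DiagonalGaussian drops the constructor flags and applies the standard reparameterization trick only while the module is in training mode:

    z = \mu + \exp\big(\tfrac{1}{2} \log \sigma^2\big) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)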
flux/flux/lora.py (new file, 74 lines)

@@ -0,0 +1,74 @@
import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    @staticmethod
    def from_base(
        linear: nn.Linear,
        r: int = 8,
        dropout: float = 0.0,
        scale: float = 20.0,
    ):
        output_dims, input_dims = linear.weight.shape
        lora_lin = LoRALinear(
            input_dims=input_dims,
            output_dims=output_dims,
            r=r,
            dropout=dropout,
            scale=scale,
        )
        lora_lin.linear = linear
        return lora_lin

    def fuse(self, de_quantize: bool = False):
        linear = self.linear
        bias = "bias" in linear
        weight = linear.weight
        dtype = weight.dtype

        output_dims, input_dims = weight.shape
        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)

        # Fold the low-rank update into the base weight: W' = W + scale * B^T A^T
        lora_b = (self.scale * self.lora_b.T).astype(dtype)
        lora_a = self.lora_a.T.astype(dtype)
        fused_linear.weight = weight + lora_b @ lora_a
        if bias:
            fused_linear.bias = linear.bias

        return fused_linear

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        r: int = 8,
        dropout: float = 0.0,
        scale: float = 20.0,
        bias: bool = False,
    ):
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        self.dropout = nn.Dropout(p=dropout)

        # Scale for the low-rank update
        self.scale = scale

        # Low rank lora weights
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, r),
        )
        self.lora_b = mx.zeros(shape=(r, output_dims))

    def __call__(self, x):
        y = self.linear(x)
        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
        return y + (self.scale * z).astype(x.dtype)
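A quick sanity check of the layer above (illustrative, not part of the diff): with lora_b initialized to zeros the wrapped layer matches its base exactly, and fuse() folds the low-rank update back into a plain nn.Linear:

# Illustrative usage, not part of the commit.
import mlx.core as mx
import mlx.nn as nn
from flux.lora import LoRALinear

base = nn.Linear(64, 64)
lora = LoRALinear.from_base(base, r=8)

x = mx.random.normal((2, 64))
# lora_b starts at zero, so the LoRA branch contributes nothing initially.
assert mx.allclose(lora(x), base(x))
# fuse() returns an equivalent plain Linear with the update baked in.
fused = lora.fuse()
assert mx.allclose(fused(x), lora(x))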
flux/txt2image.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import argparse

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm

from flux import FluxPipeline


def to_latent_size(image_size):
    h, w = image_size
    h = ((h + 15) // 16) * 16
    w = ((w + 15) // 16) * 16

    if (h, w) != image_size:
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )

    return (h // 8, w // 8)
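    # For example (illustrative): to_latent_size((500, 300)) pads to 512x304
    # pixels and returns the latent size (64, 38).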


def quantization_predicate(name, m):
    # Quantize only the layers whose input dimensions are divisible by 512
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--n_images", type=int, default=4)
    parser.add_argument(
        "--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
    )
    parser.add_argument("--steps", type=int)
    parser.add_argument("--guidance", type=float, default=4.0)
    parser.add_argument("--n_rows", type=int, default=1)
    parser.add_argument("--decoding_batch_size", type=int, default=1)
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--preload-models", action="store_true")
    parser.add_argument("--output", default="out.png")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    # Load the models
    flux = FluxPipeline("flux-" + args.model)
    args.steps = args.steps or (50 if args.model == "dev" else 2)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    if args.preload_models:
        flux.ensure_models_are_loaded()

    # Make the generator
    latent_size = to_latent_size(args.image_size)
    latents = flux.generate_latents(
        args.prompt,
        n_images=args.n_images,
        num_steps=args.steps,
        latent_size=latent_size,
        guidance=args.guidance,
        seed=args.seed,
    )

    # First we get and eval the conditioning
    conditioning = next(latents)
    mx.eval(conditioning)
    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3

    # The following is not necessary but it may help in memory constrained
    # systems by reusing the memory kept by the text encoders.
    del flux.t5
    del flux.clip

    # Actual denoising loop
    for x_t in tqdm(latents, total=args.steps):
        mx.eval(x_t)

    # The following is not necessary but it may help in memory constrained
    # systems by reusing the memory kept by the flow transformer.
    del flux.flow
    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3

    # Decode them into images
    decoded = []
    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
        decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
        mx.eval(decoded[-1])
    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3

    # Arrange them on a grid
    x = mx.concatenate(decoded, axis=0)
    x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
    B, H, W, C = x.shape
    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
    x = (x * 255).astype(mx.uint8)

    # Save them to disk
    im = Image.fromarray(np.array(x))
    im.save(args.output)

    # Report the peak memory used during generation
    if args.verbose:
        print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
        print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
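The grid arrangement relies on a reshape and transpose trick that is easy to check in isolation; a standalone sketch with concrete shapes (illustrative, not part of the commit):

# Illustrative: tile B = 4 images of shape H x W x C into a 2 x 2 grid.
import mlx.core as mx

B, H, W, C, n_rows = 4, 8, 8, 3, 2
x = mx.zeros((B, H, W, C))
grid = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
grid = grid.reshape(n_rows * H, B // n_rows * W, C)
print(grid.shape)  # (16, 16, 3)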
@@ -151,7 +151,7 @@ if __name__ == "__main__":
         help="Save the model every CHECKPOINT_EVERY steps",
     )
     parser.add_argument(
-        "--learning_rate", type=float, default="1e-6", help="Learning rate for training"
+        "--learning-rate", type=float, default=1e-6, help="Learning rate for training"
     )
     parser.add_argument(
         "--predict-x0",
@@ -201,7 +201,7 @@ if __name__ == "__main__":
     sd = StableDiffusion(args.model)
     sd.ensure_models_are_loaded()

-    optimizer = optim.Adam(learning_rate=1e-6)
+    optimizer = optim.Adam(learning_rate=args.learning_rate)

    def loss_fn(params, text, x, weights):
        sd.unet.update(params["unet"])