Address comments and add some image augmentation

This commit is contained in:
Angelos Katharopoulos
2024-10-10 17:58:01 -07:00
parent 8c3b25f88c
commit f7749ab043
2 changed files with 76 additions and 46 deletions

View File

@@ -1,7 +1,6 @@
import argparse import argparse
import json import json
import time import time
from contextlib import contextmanager
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@@ -10,23 +9,11 @@ import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np import numpy as np
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from flux import FluxPipeline from flux import FluxPipeline
from flux.lora import LoRALinear
@contextmanager
def random_state(seed=None):
s = mx.random.state[0]
try:
if seed is not None:
mx.random.seed(seed)
yield
finally:
mx.random.state[0] = s
class FinetuningDataset: class FinetuningDataset:
@@ -44,27 +31,58 @@ class FinetuningDataset:
self.t5_features = [] self.t5_features = []
self.clip_features = [] self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def encode_images(self): def encode_images(self):
"""Encode the images in the latent space to prepare for training.""" """Encode the images in the latent space to prepare for training."""
self.flux.ae.eval() self.flux.ae.eval()
for sample in tqdm(self.index["data"]): for sample in tqdm(self.index["data"]):
img = Image.open(self.dataset_base / sample["image"]) input_img = Image.open(self.dataset_base / sample["image"])
width, height = img.size for i in range(self.args.num_augmentations):
if width != height: img = self._random_crop_resize(input_img)
side = min(width, height) img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
img = img.crop(
(
(width - side) / 2,
(height - side) / 2,
(width + side) / 2,
(height + side) / 2,
)
)
img = img.resize(self.args.resolution, Image.LANCZOS)
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None]) x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(flux.dtype) x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0) mx.eval(x_0)
self.latents.append(x_0) self.latents.append(x_0)
@@ -84,9 +102,14 @@ class FinetuningDataset:
t5 = mx.concatenate(self.t5_features) t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features) clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip) mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True: while True:
indices = mx.random.randint(0, len(self.latents), (batch_size,)) x_indices = mx.random.permutation(len(self.latents))
yield xs[indices], t5[indices], clip[indices] c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
@@ -156,7 +179,7 @@ if __name__ == "__main__":
help="How many iterations to train for", help="How many iterations to train for",
) )
parser.add_argument( parser.add_argument(
"--batch_size", "--batch-size",
type=int, type=int,
default=1, default=1,
help="The batch size to use when training the stable diffusion model", help="The batch size to use when training the stable diffusion model",
@@ -167,6 +190,12 @@ if __name__ == "__main__":
default=(512, 512), default=(512, 512),
help="The resolution of the training images", help="The resolution of the training images",
) )
parser.add_argument(
"--num-augmentations",
type=int,
default=5,
help="Augment the images by random cropping and panning",
)
parser.add_argument( parser.add_argument(
"--progress-prompt", "--progress-prompt",
required=True, required=True,
@@ -219,18 +248,19 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
# Initialize the seed but different per worker if we are in a distributed
# setting.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Load the model and set it up for LoRA training. We use the same random # Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same # state when creating the LoRA layers so all workers will have the same
# initial weights. # initial weights.
mx.random.seed(0x0F0F0F0F)
flux = FluxPipeline("flux-" + args.model) flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze() flux.flow.freeze()
with random_state(0x0F0F0F0F):
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks) flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
# Reset the seed to a different seed per worker if we are in distributed
# mode so that each worker is working on different data, diffusion step and
# random noise.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Report how many parameters we are training # Report how many parameters we are training
trainable_params = tree_reduce( trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0

View File

@@ -43,14 +43,14 @@ if __name__ == "__main__":
) )
parser.add_argument("prompt") parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell") parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--n_images", type=int, default=4) parser.add_argument("--n-images", type=int, default=4)
parser.add_argument( parser.add_argument(
"--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512) "--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
) )
parser.add_argument("--steps", type=int) parser.add_argument("--steps", type=int)
parser.add_argument("--guidance", type=float, default=4.0) parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--n_rows", type=int, default=1) parser.add_argument("--n-rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1) parser.add_argument("--decoding-batch-size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true") parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true") parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png") parser.add_argument("--output", default="out.png")