Address comments and add some image augmentation

This commit is contained in:
Angelos Katharopoulos
2024-10-10 17:58:01 -07:00
parent 8c3b25f88c
commit f7749ab043
2 changed files with 76 additions and 46 deletions

View File

@@ -1,7 +1,6 @@
import argparse
import json
import time
from contextlib import contextmanager
from functools import partial
from pathlib import Path
@@ -10,23 +9,11 @@ import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
from flux.lora import LoRALinear
@contextmanager
def random_state(seed=None):
s = mx.random.state[0]
try:
if seed is not None:
mx.random.seed(seed)
yield
finally:
mx.random.state[0] = s
class FinetuningDataset:
@@ -44,29 +31,60 @@ class FinetuningDataset:
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
img = Image.open(self.dataset_base / sample["image"])
width, height = img.size
if width != height:
side = min(width, height)
img = img.crop(
(
(width - side) / 2,
(height - side) / 2,
(width + side) / 2,
(height + side) / 2,
)
)
img = img.resize(self.args.resolution, Image.LANCZOS)
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
@@ -84,9 +102,14 @@ class FinetuningDataset:
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
indices = mx.random.randint(0, len(self.latents), (batch_size,))
yield xs[indices], t5[indices], clip[indices]
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
def generate_progress_images(iteration, flux, args):
@@ -156,7 +179,7 @@ if __name__ == "__main__":
help="How many iterations to train for",
)
parser.add_argument(
"--batch_size",
"--batch-size",
type=int,
default=1,
help="The batch size to use when training the stable diffusion model",
@@ -167,6 +190,12 @@ if __name__ == "__main__":
default=(512, 512),
help="The resolution of the training images",
)
parser.add_argument(
"--num-augmentations",
type=int,
default=5,
help="Augment the images by random cropping and panning",
)
parser.add_argument(
"--progress-prompt",
required=True,
@@ -219,17 +248,18 @@ if __name__ == "__main__":
args = parser.parse_args()
# Initialize the seed but different per worker if we are in a distributed
# setting.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same
# initial weights.
mx.random.seed(0x0F0F0F0F)
flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze()
with random_state(0x0F0F0F0F):
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
# Reset the seed to a different seed per worker if we are in distributed
# mode so that each worker is working on different data, diffusion step and
# random noise.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Report how many parameters we are training
trainable_params = tree_reduce(

View File

@@ -43,14 +43,14 @@ if __name__ == "__main__":
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument("--n-images", type=int, default=4)
parser.add_argument(
"--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
)
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--n-rows", type=int, default=1)
parser.add_argument("--decoding-batch-size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")