mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Address comments and add some image augmentation
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,23 +9,11 @@ import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
from flux.lora import LoRALinear
|
||||
|
||||
|
||||
@contextmanager
def random_state(seed=None):
    """Run the ``with`` body under a temporary MLX random state.

    The first entry of ``mx.random.state`` is captured on entry and written
    back on exit, so whatever random draws happen inside the block do not
    affect the sequence seen by the code that follows it.

    Args:
        seed: When not ``None``, ``mx.random.seed(seed)`` is called before the
            body runs. When ``None`` the current state is left untouched on
            entry (but is still restored on exit).
    """
    # NOTE(review): presumably state[0] is the global PRNG key -- confirm
    # against the MLX version in use.
    saved_key = mx.random.state[0]
    try:
        if seed is not None:
            mx.random.seed(seed)
        yield
    finally:
        # Restore even when the body (or the seeding itself) raised.
        mx.random.state[0] = saved_key
|
||||
|
||||
|
||||
class FinetuningDataset:
|
||||
@@ -44,29 +31,60 @@ class FinetuningDataset:
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
    """Return a randomly cropped copy of ``img`` resized to the target size.

    Three steps are applied in order:

    1. A random window covering 80%-100% of each side (one shared factor for
       both sides, but never smaller than the target resolution) is cut out
       at a random offset.
    2. The largest centered rectangle with the target aspect ratio is cut out
       of that window.
    3. The result is resized to ``self.args.resolution`` with Lanczos
       resampling.

    Args:
        img: Image object exposing ``size``, ``crop`` and ``resize``
            (presumably a ``PIL.Image.Image`` -- confirm with the caller).

    Returns:
        An ``mx.array`` built from the NumPy view of the resized image.
    """
    target = self.args.resolution
    img_w, img_h = img.size

    # Four uniform samples are drawn on the CPU stream; the fourth one is
    # drawn but not used below.
    scale, pan_x, pan_y, _ = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()

    # Step 1: random window, clamped from below by the target resolution.
    keep = 0.8 + 0.2 * scale
    crop_w = max(keep * img_w, target[0])
    crop_h = max(keep * img_h, target[1])
    left = (img_w - crop_w) * pan_x
    top = (img_h - crop_h) * pan_y
    img = img.crop((left, top, crop_w + left, crop_h + top))

    # Step 2: largest centered rectangle with the target aspect ratio. Both
    # candidates are computed up front; the height-bound one wins if it fits.
    aspect = target[0] / target[1]
    by_height = (crop_h * aspect, crop_h)
    by_width = (crop_w, crop_w / aspect)
    fit_w, fit_h = by_height if by_height[0] <= crop_w else by_width
    img = img.crop(
        (
            (crop_w - fit_w) / 2,
            (crop_h - fit_h) / 2,
            (crop_w + fit_w) / 2,
            (crop_h + fit_h) / 2,
        )
    )

    # Step 3: bring the crop to the exact training resolution.
    img = img.resize(target, Image.LANCZOS)

    return mx.array(np.array(img))
|
||||
|
||||
def encode_images(self):
    """Encode the images in the latent space to prepare for training.

    For every entry of ``self.index["data"]`` the image found at
    ``self.dataset_base / sample["image"]`` is opened and
    ``self.args.num_augmentations`` random crops of it are produced with
    ``self._random_crop_resize``. Each crop is encoded by ``self.flux.ae`` and
    the resulting latent is appended to ``self.latents``.

    Exactly ``num_augmentations`` latents are stored per sample, in sample
    order, so latent ``k`` belongs to sample ``k // num_augmentations``. The
    previous centre-crop path, which stored one additional un-augmented latent
    per sample and read ``flux`` from module scope instead of ``self.flux``,
    is intentionally not kept because it broke that layout.
    """
    self.flux.ae.eval()
    for sample in tqdm(self.index["data"]):
        input_img = Image.open(self.dataset_base / sample["image"])
        for _ in range(self.args.num_augmentations):
            img = self._random_crop_resize(input_img)
            # Keep the first three channels and rescale from [0, 255] to
            # [-1, 1] (assumes 8-bit pixel data -- TODO confirm).
            img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
            x_0 = self.flux.ae.encode(img[None])
            x_0 = x_0.astype(self.flux.dtype)
            # Evaluate now so each latent is materialized before the next
            # image is processed.
            mx.eval(x_0)
            self.latents.append(x_0)
|
||||
|
||||
def encode_prompts(self):
|
||||
"""Pre-encode the prompts so that we don't recompute them during
|
||||
@@ -84,9 +102,14 @@ class FinetuningDataset:
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
indices = mx.random.randint(0, len(self.latents), (batch_size,))
|
||||
yield xs[indices], t5[indices], clip[indices]
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i : i + batch_size]
|
||||
c_i = c_indices[i : i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
@@ -156,7 +179,7 @@ if __name__ == "__main__":
|
||||
help="How many iterations to train for",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The batch size to use when training the stable diffusion model",
|
||||
@@ -167,6 +190,12 @@ if __name__ == "__main__":
|
||||
default=(512, 512),
|
||||
help="The resolution of the training images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-augmentations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Augment the images by random cropping and panning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-prompt",
|
||||
required=True,
|
||||
@@ -219,17 +248,18 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize the seed but different per worker if we are in a distributed
|
||||
# setting.
|
||||
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
|
||||
|
||||
# Load the model and set it up for LoRA training. We use the same random
|
||||
# state when creating the LoRA layers so all workers will have the same
|
||||
# initial weights.
|
||||
mx.random.seed(0x0F0F0F0F)
|
||||
flux = FluxPipeline("flux-" + args.model)
|
||||
flux.flow.freeze()
|
||||
with random_state(0x0F0F0F0F):
|
||||
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
|
||||
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
|
||||
|
||||
# Reset the seed to a different seed per worker if we are in distributed
|
||||
# mode so that each worker is working on different data, diffusion step and
|
||||
# random noise.
|
||||
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
|
||||
|
||||
# Report how many parameters we are training
|
||||
trainable_params = tree_reduce(
|
||||
|
@@ -43,14 +43,14 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("prompt")
|
||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
||||
parser.add_argument("--n_images", type=int, default=4)
|
||||
parser.add_argument("--n-images", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
|
||||
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
|
||||
)
|
||||
parser.add_argument("--steps", type=int)
|
||||
parser.add_argument("--guidance", type=float, default=4.0)
|
||||
parser.add_argument("--n_rows", type=int, default=1)
|
||||
parser.add_argument("--decoding_batch_size", type=int, default=1)
|
||||
parser.add_argument("--n-rows", type=int, default=1)
|
||||
parser.add_argument("--decoding-batch-size", type=int, default=1)
|
||||
parser.add_argument("--quantize", "-q", action="store_true")
|
||||
parser.add_argument("--preload-models", action="store_true")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
|
Reference in New Issue
Block a user