Address comments and add some image augmentation
(commit to https://github.com/ml-explore/mlx-examples.git)

flux/dreambooth.py

@@ -1,7 +1,6 @@
 import argparse
 import json
 import time
-from contextlib import contextmanager
 from functools import partial
 from pathlib import Path

@@ -10,23 +9,11 @@ import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 from mlx.nn.utils import average_gradients
-from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
+from mlx.utils import tree_flatten, tree_map, tree_reduce
 from PIL import Image
 from tqdm import tqdm

 from flux import FluxPipeline
-from flux.lora import LoRALinear
-
-
-@contextmanager
-def random_state(seed=None):
-    s = mx.random.state[0]
-    try:
-        if seed is not None:
-            mx.random.seed(seed)
-        yield
-    finally:
-        mx.random.state[0] = s


 class FinetuningDataset:
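
Note: the deleted random_state context manager saved MLX's global PRNG key, optionally re-seeded, and restored the key on exit; the commit replaces it with plain re-seeding (see the last dreambooth.py hunk below). A minimal self-contained sketch of the pattern, with illustrative usage that is not part of this commit:

    from contextlib import contextmanager

    import mlx.core as mx

    @contextmanager
    def random_state(seed=None):
        # Save the global PRNG key, optionally re-seed, restore on exit.
        s = mx.random.state[0]
        try:
            if seed is not None:
                mx.random.seed(seed)
            yield
        finally:
            mx.random.state[0] = s

    # Both blocks print the same value, and the random stream outside
    # the blocks continues as if neither had drawn anything.
    with random_state(42):
        print(mx.random.uniform())
    with random_state(42):
        print(mx.random.uniform())
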
@@ -44,27 +31,58 @@ class FinetuningDataset:
         self.t5_features = []
         self.clip_features = []

+    def _random_crop_resize(self, img):
+        resolution = self.args.resolution
+        width, height = img.size
+
+        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+        # Random crop the input image between 0.8 to 1.0 of its original dimensions
+        crop_size = (
+            max((0.8 + 0.2 * a) * width, resolution[0]),
+            max((0.8 + 0.2 * a) * height, resolution[1]),
+        )
+        pan = (width - crop_size[0], height - crop_size[1])
+        img = img.crop(
+            (
+                pan[0] * b,
+                pan[1] * c,
+                crop_size[0] + pan[0] * b,
+                crop_size[1] + pan[1] * c,
+            )
+        )
+
+        # Fit the largest rectangle with the ratio of resolution in the image
+        # rectangle.
+        width, height = crop_size
+        ratio = resolution[0] / resolution[1]
+        r1 = (height * ratio, height)
+        r2 = (width, width / ratio)
+        r = r1 if r1[0] <= width else r2
+        img = img.crop(
+            (
+                (width - r[0]) / 2,
+                (height - r[1]) / 2,
+                (width + r[0]) / 2,
+                (height + r[1]) / 2,
+            )
+        )
+
+        # Finally resize the image to resolution
+        img = img.resize(resolution, Image.LANCZOS)
+
+        return mx.array(np.array(img))
+
     def encode_images(self):
         """Encode the images in the latent space to prepare for training."""
         self.flux.ae.eval()
         for sample in tqdm(self.index["data"]):
-            img = Image.open(self.dataset_base / sample["image"])
-            width, height = img.size
-            if width != height:
-                side = min(width, height)
-                img = img.crop(
-                    (
-                        (width - side) / 2,
-                        (height - side) / 2,
-                        (width + side) / 2,
-                        (height + side) / 2,
-                    )
-                )
-            img = img.resize(self.args.resolution, Image.LANCZOS)
-            img = mx.array(np.array(img))
-            img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
-            x_0 = self.flux.ae.encode(img[None])
-            x_0 = x_0.astype(flux.dtype)
-            mx.eval(x_0)
-            self.latents.append(x_0)
+            input_img = Image.open(self.dataset_base / sample["image"])
+            for i in range(self.args.num_augmentations):
+                img = self._random_crop_resize(input_img)
+                img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+                x_0 = self.flux.ae.encode(img[None])
+                x_0 = x_0.astype(self.flux.dtype)
+                mx.eval(x_0)
+                self.latents.append(x_0)

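Two things happen in this hunk: the fixed center-crop in encode_images becomes num_augmentations random crop-resizes per image, and the stray references to the global flux (flux.dtype) are corrected to self.flux.dtype. To make the crop arithmetic concrete, the sketch below replays _random_crop_resize with fixed draws a = b = c = 0.5 in place of mx.random.uniform (the fourth draw, d, is unused by the method) and a made-up 1024x768 input:

    # Replaying the crop arithmetic with fixed "random" draws.
    resolution = (512, 512)
    width, height = 1024, 768
    a = b = c = 0.5

    # Crop to 0.8-1.0 of the original size, never below the target resolution.
    crop_size = (
        max((0.8 + 0.2 * a) * width, resolution[0]),   # 921.6
        max((0.8 + 0.2 * a) * height, resolution[1]),  # 691.2
    )
    # Pan the crop window inside the image; b and c choose the offset.
    pan = (width - crop_size[0], height - crop_size[1])  # (102.4, 76.8)
    box = (pan[0] * b, pan[1] * c, crop_size[0] + pan[0] * b, crop_size[1] + pan[1] * c)

    # Center-crop the largest rectangle with the target aspect ratio (1.0 here).
    w, h = crop_size
    ratio = resolution[0] / resolution[1]
    r1 = (h * ratio, h)                 # (691.2, 691.2), fits since 691.2 <= 921.6
    r2 = (w, w / ratio)
    r = r1 if r1[0] <= w else r2
    print(box, r)                       # the final Image.resize maps r to 512x512
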
@@ -84,9 +102,14 @@ class FinetuningDataset:
         t5 = mx.concatenate(self.t5_features)
         clip = mx.concatenate(self.clip_features)
         mx.eval(xs, t5, clip)
+        n_aug = self.args.num_augmentations
         while True:
-            indices = mx.random.randint(0, len(self.latents), (batch_size,))
-            yield xs[indices], t5[indices], clip[indices]
+            x_indices = mx.random.permutation(len(self.latents))
+            c_indices = x_indices // n_aug
+            for i in range(0, len(self.latents), batch_size):
+                x_i = x_indices[i : i + batch_size]
+                c_i = c_indices[i : i + batch_size]
+                yield xs[x_i], t5[c_i], clip[c_i]


 def generate_progress_images(iteration, flux, args):
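
The batching scheme changes along with the augmentations: instead of sampling indices with replacement via mx.random.randint, the iterator now walks a fresh permutation of all cached latents on each pass. Because encode_images appends num_augmentations consecutive latents per source image while the T5/CLIP features are stored once per image, x_indices // n_aug maps each augmented latent back to its caption. A tiny sketch with made-up indices:

    # Latent j was produced from source image (and caption) j // n_aug.
    n_aug = 5
    x_indices = [7, 0, 12, 9]                   # a slice of a permutation of latents
    c_indices = [j // n_aug for j in x_indices]
    print(c_indices)                            # [1, 0, 2, 1]
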
@@ -156,7 +179,7 @@ if __name__ == "__main__":
         help="How many iterations to train for",
     )
     parser.add_argument(
-        "--batch_size",
+        "--batch-size",
         type=int,
         default=1,
         help="The batch size to use when training the stable diffusion model",
@@ -167,6 +190,12 @@ if __name__ == "__main__":
         default=(512, 512),
         help="The resolution of the training images",
     )
+    parser.add_argument(
+        "--num-augmentations",
+        type=int,
+        default=5,
+        help="Augment the images by random cropping and panning",
+    )
     parser.add_argument(
         "--progress-prompt",
         required=True,
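
The new flag trades preprocessing time and latent-cache memory for data diversity: encode_images runs the VAE encoder once per augmentation, so both scale linearly with it. Illustrative numbers only:

    num_images = 20
    num_augmentations = 5                  # the flag's default
    print(num_images * num_augmentations)  # 100 cached latents to encode and store
    print(num_images)                      # still only 20 T5/CLIP text encodings
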
@@ -219,18 +248,19 @@ if __name__ == "__main__":

     args = parser.parse_args()

-    # Initialize the seed but different per worker if we are in a distributed
-    # setting.
-    mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
-
     # Load the model and set it up for LoRA training. We use the same random
     # state when creating the LoRA layers so all workers will have the same
     # initial weights.
+    mx.random.seed(0x0F0F0F0F)
     flux = FluxPipeline("flux-" + args.model)
     flux.flow.freeze()
-    with random_state(0x0F0F0F0F):
-        flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
+    flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
+
+    # Reset the seed to a different seed per worker if we are in distributed
+    # mode so that each worker is working on different data, diffusion step and
+    # random noise.
+    mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())

     # Report how many parameters we are training
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
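
The reordered seeding is the functional change in this hunk: every worker now builds the LoRA layers under the same fixed seed, so the initial adapter weights are identical across ranks, and only afterwards switches to a rank-dependent seed so that data order, diffusion timesteps and noise differ per worker. A minimal sketch of the two-phase pattern (assuming a distributed launch; on a single process rank() is 0):

    import mlx.core as mx

    group = mx.distributed.init()

    mx.random.seed(0x0F0F0F0F)             # identical on every worker
    shared = mx.random.normal((2,))        # e.g. LoRA init, same on all ranks

    mx.random.seed(0xF0F0F0F0 + group.rank())
    private = mx.random.normal((2,))       # per-rank stream for data and noise

    print(group.rank(), shared, private)
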
flux/txt2image.py

@@ -43,14 +43,14 @@ if __name__ == "__main__":
     )
     parser.add_argument("prompt")
     parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
-    parser.add_argument("--n_images", type=int, default=4)
+    parser.add_argument("--n-images", type=int, default=4)
     parser.add_argument(
-        "--image_size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
+        "--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
     )
     parser.add_argument("--steps", type=int)
     parser.add_argument("--guidance", type=float, default=4.0)
-    parser.add_argument("--n_rows", type=int, default=1)
-    parser.add_argument("--decoding_batch_size", type=int, default=1)
+    parser.add_argument("--n-rows", type=int, default=1)
+    parser.add_argument("--decoding-batch-size", type=int, default=1)
     parser.add_argument("--quantize", "-q", action="store_true")
     parser.add_argument("--preload-models", action="store_true")
     parser.add_argument("--output", default="out.png")
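
The renames in this file are cosmetic CLI polish: argparse converts dashes in option strings to underscores for attribute access, so args.n_images, args.image_size, args.n_rows and args.decoding_batch_size keep working unchanged. For example:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n-images", type=int, default=4)
    args = parser.parse_args(["--n-images", "8"])
    print(args.n_images)  # 8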