mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
FLUX: Optimize dataset loading logic (#1038)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@@ -13,105 +12,8 @@ import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
class FinetuningDataset:
|
||||
def __init__(self, flux, args):
|
||||
self.args = args
|
||||
self.flux = flux
|
||||
self.dataset_base = Path(args.dataset)
|
||||
dataset_index = self.dataset_base / "index.json"
|
||||
if not dataset_index.exists():
|
||||
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
|
||||
with open(dataset_index, "r") as f:
|
||||
self.index = json.load(f)
|
||||
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * a) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * b,
|
||||
pan[1] * c,
|
||||
crop_size[0] + pan[0] * b,
|
||||
crop_size[1] + pan[1] * c,
|
||||
)
|
||||
)
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def encode_images(self):
|
||||
"""Encode the images in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for sample in tqdm(self.index["data"]):
|
||||
input_img = Image.open(self.dataset_base / sample["image"])
|
||||
for i in range(self.args.num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def encode_prompts(self):
|
||||
"""Pre-encode the prompts so that we don't recompute them during
|
||||
training (doesn't allow finetuning the text encoders)."""
|
||||
for sample in tqdm(self.index["data"]):
|
||||
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i : i + batch_size]
|
||||
c_i = c_indices[i : i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
||||
from flux import FluxPipeline, Trainer, load_dataset
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
@@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Finetune Flux to generate images with a specific subject"
|
||||
)
|
||||
@@ -247,7 +150,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument("dataset")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the model and set it up for LoRA training. We use the same random
|
||||
@@ -267,7 +174,7 @@ if __name__ == "__main__":
|
||||
trainable_params = tree_reduce(
|
||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
||||
)
|
||||
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
|
||||
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
|
||||
|
||||
# Set up the optimizer and training steps. The steps are a bit verbose to
|
||||
# support gradient accumulation together with compilation.
|
||||
@@ -340,10 +247,10 @@ if __name__ == "__main__":
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
print("Create the training dataset.", flush=True)
|
||||
dataset = FinetuningDataset(flux, args)
|
||||
dataset.encode_images()
|
||||
dataset.encode_prompts()
|
||||
dataset = load_dataset(args.dataset)
|
||||
trainer = Trainer(flux, dataset, args)
|
||||
trainer.encode_dataset()
|
||||
|
||||
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
||||
|
||||
# An initial generation to compare
|
||||
@@ -352,7 +259,7 @@ if __name__ == "__main__":
|
||||
grads = None
|
||||
losses = []
|
||||
tic = time.time()
|
||||
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
|
||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||
mx.eval(loss, grads, state)
|
||||
losses.append(loss.item())
|
||||
@@ -361,7 +268,7 @@ if __name__ == "__main__":
|
||||
toc = time.time()
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||
print(
|
||||
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"It/s: {10 / (toc - tic):.3f} "
|
||||
f"Peak mem: {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
|
Reference in New Issue
Block a user