FLUX: support huggingface dataset

This commit is contained in:
madroid 2024-10-13 01:38:58 +08:00
parent ca88343118
commit b0de67ec03
3 changed files with 162 additions and 98 deletions

View File

@ -1,7 +1,6 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import json
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -10,12 +9,11 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np import numpy as np
from PIL import Image
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline, load_dataset from flux import FluxPipeline, load_dataset, Trainer
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
@ -69,7 +67,7 @@ def setup_arg_parser():
parser.add_argument( parser.add_argument(
"--model", "--model",
default="dev", default="schnell",
choices=[ choices=[
"dev", "dev",
"schnell", "schnell",
@ -188,6 +186,7 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule) optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state] state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance): def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -198,12 +197,14 @@ if __name__ == "__main__":
return loss return loss
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)( return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance x, t5_feat, clip_feat, guidance
) )
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -212,6 +213,7 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads) grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads return loss, grads
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -227,6 +229,7 @@ if __name__ == "__main__":
return loss return loss
# We simply route to the appropriate step based on whether we have # We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an # gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step. # update or simply computing and accumulating gradients in this step.
@ -249,10 +252,12 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads x, t5_feat, clip_feat, guidance, prev_grads
) )
print("Create the training dataset.", flush=True)
# print("Create the training dataset.", flush=True)
dataset = load_dataset(flux, args) dataset = load_dataset(flux, args)
dataset.encode_images() trainer = Trainer(flux, dataset, args)
dataset.encode_prompts() trainer.encode_dataset()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare # An initial generation to compare
@ -261,16 +266,16 @@ if __name__ == "__main__":
grads = None grads = None
losses = [] losses = []
tic = time.time() tic = time.time()
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state) mx.eval(loss, grads, state)
losses.append(loss.item()) losses.append(loss.item())
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
toc = time.time() toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3 peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
print( print(
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} " f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB", f"Peak mem: {peak_mem:.3f} GB",
flush=True, flush=True,

View File

@ -1,107 +1,68 @@
import json import json
from pathlib import Path from pathlib import Path
import mlx.core as mx
import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm
class Dataset: class Dataset:
def __init__(self, flux, args): def __init__(self, flux, args, data):
self.args = args self.args = args
self.flux = flux self.flux = flux
self._data = data
def __getitem__(self, index: int):
item = self._data[index]
image = item['image']
prompt = item['prompt']
return image, prompt
def __len__(self):
if self._data is None:
return 0
return len(self._data)
class LocalDataset(Dataset):
def __init__(self, flux, args, data_file):
self.dataset_base = Path(args.dataset) self.dataset_base = Path(args.dataset)
data_file = self.dataset_base / "train.jsonl"
if not data_file.exists():
raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.")
with open(data_file, "r") as fid: with open(data_file, "r") as fid:
self.data = [json.loads(l) for l in fid] self._data = [json.loads(l) for l in fid]
self.latents = [] super().__init__(flux, args, self._data)
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img): def __getitem__(self, index: int):
resolution = self.args.resolution item = self._data[index]
width, height = img.size image = Image.open(self.dataset_base / item["image"])
return image, item["prompt"]
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions class HuggingFaceDataset(Dataset):
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image def __init__(self, flux, args):
# rectangle. from datasets import load_dataset
width, height = crop_size df = load_dataset(args.dataset)["train"]
ratio = resolution[0] / resolution[1] self._data = df.data
r1 = (height * ratio, height) super().__init__(flux, args, df)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution def __getitem__(self, index: int):
img = img.resize(resolution, Image.LANCZOS) item = self._data[index]
return item['image'], item['prompt']
return mx.array(np.array(img))
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.data, desc="encode images"):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.data, desc="encode prompts"):
t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i: i + batch_size]
c_i = c_indices[i: i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
def load_dataset(flux, args): def load_dataset(flux, args):
return Dataset(flux, args) dataset_base = Path(args.dataset)
data_file = dataset_base / "train.jsonl"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
# print(f"load local dataset: {data_file}")
dataset = LocalDataset(flux, args, data_file)
else:
print(f"Load the Hugging Face dataset {args.dataset} .", flush=True)
# print(f"load Hugging Face dataset: {args.dataset}")
dataset = HuggingFaceDataset(flux, args)
return dataset

98
flux/flux/trainer.py Normal file
View File

@ -0,0 +1,98 @@
import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from .datasets import Dataset
from .flux import FluxPipeline
class Trainer:
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
self.flux = flux
self.dataset = dataset
self.args = args
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
for i in range(num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def _encode_prompt(self, prompt):
t5_tok, clip_tok = self.flux.tokenize([prompt])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def encode_dataset(self):
"""Encode the images & prompt in the latent space to prepare for training."""
self.flux.ae.eval()
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
self._encode_image(image, self.args.num_augmentations)
self._encode_prompt(prompt)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i: i + batch_size]
c_i = c_indices[i: i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]