mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 03:19:23 +08:00
FLUX: support huggingface dataset
This commit is contained in:
parent
ca88343118
commit
b0de67ec03
@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -10,12 +9,11 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from flux import FluxPipeline, load_dataset
|
from flux import FluxPipeline, load_dataset, Trainer
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args):
|
||||||
@ -69,7 +67,7 @@ def setup_arg_parser():
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
default="dev",
|
default="schnell",
|
||||||
choices=[
|
choices=[
|
||||||
"dev",
|
"dev",
|
||||||
"schnell",
|
"schnell",
|
||||||
@ -188,6 +186,7 @@ if __name__ == "__main__":
|
|||||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def single_step(x, t5_feat, clip_feat, guidance):
|
def single_step(x, t5_feat, clip_feat, guidance):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -198,12 +197,14 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
||||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
x, t5_feat, clip_feat, guidance
|
x, t5_feat, clip_feat, guidance
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -212,6 +213,7 @@ if __name__ == "__main__":
|
|||||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||||
return loss, grads
|
return loss, grads
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -227,6 +229,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
# We simply route to the appropriate step based on whether we have
|
# We simply route to the appropriate step based on whether we have
|
||||||
# gradients from a previous step and whether we should be performing an
|
# gradients from a previous step and whether we should be performing an
|
||||||
# update or simply computing and accumulating gradients in this step.
|
# update or simply computing and accumulating gradients in this step.
|
||||||
@ -249,10 +252,12 @@ if __name__ == "__main__":
|
|||||||
x, t5_feat, clip_feat, guidance, prev_grads
|
x, t5_feat, clip_feat, guidance, prev_grads
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Create the training dataset.", flush=True)
|
|
||||||
|
# print("Create the training dataset.", flush=True)
|
||||||
dataset = load_dataset(flux, args)
|
dataset = load_dataset(flux, args)
|
||||||
dataset.encode_images()
|
trainer = Trainer(flux, dataset, args)
|
||||||
dataset.encode_prompts()
|
trainer.encode_dataset()
|
||||||
|
|
||||||
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
||||||
|
|
||||||
# An initial generation to compare
|
# An initial generation to compare
|
||||||
@ -261,16 +266,16 @@ if __name__ == "__main__":
|
|||||||
grads = None
|
grads = None
|
||||||
losses = []
|
losses = []
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
|
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||||
mx.eval(loss, grads, state)
|
mx.eval(loss, grads, state)
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
|
||||||
print(
|
print(
|
||||||
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
|
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||||
f"It/s: {10 / (toc - tic):.3f} "
|
f"It/s: {10 / (toc - tic):.3f} "
|
||||||
f"Peak mem: {peak_mem:.3f} GB",
|
f"Peak mem: {peak_mem:.3f} GB",
|
||||||
flush=True,
|
flush=True,
|
||||||
|
@ -1,107 +1,68 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
def __init__(self, flux, args):
|
def __init__(self, flux, args, data):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.flux = flux
|
self.flux = flux
|
||||||
|
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
item = self._data[index]
|
||||||
|
image = item['image']
|
||||||
|
prompt = item['prompt']
|
||||||
|
|
||||||
|
return image, prompt
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self._data is None:
|
||||||
|
return 0
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, flux, args, data_file):
|
||||||
self.dataset_base = Path(args.dataset)
|
self.dataset_base = Path(args.dataset)
|
||||||
data_file = self.dataset_base / "train.jsonl"
|
|
||||||
if not data_file.exists():
|
|
||||||
raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.")
|
|
||||||
with open(data_file, "r") as fid:
|
with open(data_file, "r") as fid:
|
||||||
self.data = [json.loads(l) for l in fid]
|
self._data = [json.loads(l) for l in fid]
|
||||||
|
|
||||||
self.latents = []
|
super().__init__(flux, args, self._data)
|
||||||
self.t5_features = []
|
|
||||||
self.clip_features = []
|
|
||||||
|
|
||||||
def _random_crop_resize(self, img):
|
def __getitem__(self, index: int):
|
||||||
resolution = self.args.resolution
|
item = self._data[index]
|
||||||
width, height = img.size
|
image = Image.open(self.dataset_base / item["image"])
|
||||||
|
return image, item["prompt"]
|
||||||
|
|
||||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
|
||||||
|
|
||||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
class HuggingFaceDataset(Dataset):
|
||||||
crop_size = (
|
|
||||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
|
||||||
max((0.8 + 0.2 * a) * height, resolution[1]),
|
|
||||||
)
|
|
||||||
pan = (width - crop_size[0], height - crop_size[1])
|
|
||||||
img = img.crop(
|
|
||||||
(
|
|
||||||
pan[0] * b,
|
|
||||||
pan[1] * c,
|
|
||||||
crop_size[0] + pan[0] * b,
|
|
||||||
crop_size[1] + pan[1] * c,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fit the largest rectangle with the ratio of resolution in the image
|
def __init__(self, flux, args):
|
||||||
# rectangle.
|
from datasets import load_dataset
|
||||||
width, height = crop_size
|
df = load_dataset(args.dataset)["train"]
|
||||||
ratio = resolution[0] / resolution[1]
|
self._data = df.data
|
||||||
r1 = (height * ratio, height)
|
super().__init__(flux, args, df)
|
||||||
r2 = (width, width / ratio)
|
|
||||||
r = r1 if r1[0] <= width else r2
|
|
||||||
img = img.crop(
|
|
||||||
(
|
|
||||||
(width - r[0]) / 2,
|
|
||||||
(height - r[1]) / 2,
|
|
||||||
(width + r[0]) / 2,
|
|
||||||
(height + r[1]) / 2,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Finally resize the image to resolution
|
def __getitem__(self, index: int):
|
||||||
img = img.resize(resolution, Image.LANCZOS)
|
item = self._data[index]
|
||||||
|
return item['image'], item['prompt']
|
||||||
return mx.array(np.array(img))
|
|
||||||
|
|
||||||
def encode_images(self):
|
|
||||||
"""Encode the images in the latent space to prepare for training."""
|
|
||||||
self.flux.ae.eval()
|
|
||||||
for sample in tqdm(self.data, desc="encode images"):
|
|
||||||
input_img = Image.open(self.dataset_base / sample["image"])
|
|
||||||
for i in range(self.args.num_augmentations):
|
|
||||||
img = self._random_crop_resize(input_img)
|
|
||||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
|
||||||
x_0 = self.flux.ae.encode(img[None])
|
|
||||||
x_0 = x_0.astype(self.flux.dtype)
|
|
||||||
mx.eval(x_0)
|
|
||||||
self.latents.append(x_0)
|
|
||||||
|
|
||||||
def encode_prompts(self):
|
|
||||||
"""Pre-encode the prompts so that we don't recompute them during
|
|
||||||
training (doesn't allow finetuning the text encoders)."""
|
|
||||||
for sample in tqdm(self.data, desc="encode prompts"):
|
|
||||||
t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]])
|
|
||||||
t5_feat = self.flux.t5(t5_tok)
|
|
||||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
|
||||||
mx.eval(t5_feat, clip_feat)
|
|
||||||
self.t5_features.append(t5_feat)
|
|
||||||
self.clip_features.append(clip_feat)
|
|
||||||
|
|
||||||
def iterate(self, batch_size):
|
|
||||||
xs = mx.concatenate(self.latents)
|
|
||||||
t5 = mx.concatenate(self.t5_features)
|
|
||||||
clip = mx.concatenate(self.clip_features)
|
|
||||||
mx.eval(xs, t5, clip)
|
|
||||||
n_aug = self.args.num_augmentations
|
|
||||||
while True:
|
|
||||||
x_indices = mx.random.permutation(len(self.latents))
|
|
||||||
c_indices = x_indices // n_aug
|
|
||||||
for i in range(0, len(self.latents), batch_size):
|
|
||||||
x_i = x_indices[i: i + batch_size]
|
|
||||||
c_i = c_indices[i: i + batch_size]
|
|
||||||
yield xs[x_i], t5[c_i], clip[c_i]
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(flux, args):
|
def load_dataset(flux, args):
|
||||||
return Dataset(flux, args)
|
dataset_base = Path(args.dataset)
|
||||||
|
data_file = dataset_base / "train.jsonl"
|
||||||
|
|
||||||
|
if data_file.exists():
|
||||||
|
print(f"Load the local dataset {data_file} .", flush=True)
|
||||||
|
# print(f"load local dataset: {data_file}")
|
||||||
|
dataset = LocalDataset(flux, args, data_file)
|
||||||
|
else:
|
||||||
|
print(f"Load the Hugging Face dataset {args.dataset} .", flush=True)
|
||||||
|
# print(f"load Hugging Face dataset: {args.dataset}")
|
||||||
|
dataset = HuggingFaceDataset(flux, args)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
98
flux/flux/trainer.py
Normal file
98
flux/flux/trainer.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .datasets import Dataset
|
||||||
|
from .flux import FluxPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
|
||||||
|
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
|
||||||
|
self.flux = flux
|
||||||
|
self.dataset = dataset
|
||||||
|
self.args = args
|
||||||
|
self.latents = []
|
||||||
|
self.t5_features = []
|
||||||
|
self.clip_features = []
|
||||||
|
|
||||||
|
def _random_crop_resize(self, img):
|
||||||
|
resolution = self.args.resolution
|
||||||
|
width, height = img.size
|
||||||
|
|
||||||
|
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||||
|
|
||||||
|
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||||
|
crop_size = (
|
||||||
|
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||||
|
max((0.8 + 0.2 * a) * height, resolution[1]),
|
||||||
|
)
|
||||||
|
pan = (width - crop_size[0], height - crop_size[1])
|
||||||
|
img = img.crop(
|
||||||
|
(
|
||||||
|
pan[0] * b,
|
||||||
|
pan[1] * c,
|
||||||
|
crop_size[0] + pan[0] * b,
|
||||||
|
crop_size[1] + pan[1] * c,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fit the largest rectangle with the ratio of resolution in the image
|
||||||
|
# rectangle.
|
||||||
|
width, height = crop_size
|
||||||
|
ratio = resolution[0] / resolution[1]
|
||||||
|
r1 = (height * ratio, height)
|
||||||
|
r2 = (width, width / ratio)
|
||||||
|
r = r1 if r1[0] <= width else r2
|
||||||
|
img = img.crop(
|
||||||
|
(
|
||||||
|
(width - r[0]) / 2,
|
||||||
|
(height - r[1]) / 2,
|
||||||
|
(width + r[0]) / 2,
|
||||||
|
(height + r[1]) / 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally resize the image to resolution
|
||||||
|
img = img.resize(resolution, Image.LANCZOS)
|
||||||
|
|
||||||
|
return mx.array(np.array(img))
|
||||||
|
|
||||||
|
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
|
||||||
|
for i in range(num_augmentations):
|
||||||
|
img = self._random_crop_resize(input_img)
|
||||||
|
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||||
|
x_0 = self.flux.ae.encode(img[None])
|
||||||
|
x_0 = x_0.astype(self.flux.dtype)
|
||||||
|
mx.eval(x_0)
|
||||||
|
self.latents.append(x_0)
|
||||||
|
|
||||||
|
def _encode_prompt(self, prompt):
|
||||||
|
t5_tok, clip_tok = self.flux.tokenize([prompt])
|
||||||
|
t5_feat = self.flux.t5(t5_tok)
|
||||||
|
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||||
|
mx.eval(t5_feat, clip_feat)
|
||||||
|
self.t5_features.append(t5_feat)
|
||||||
|
self.clip_features.append(clip_feat)
|
||||||
|
|
||||||
|
def encode_dataset(self):
|
||||||
|
"""Encode the images & prompt in the latent space to prepare for training."""
|
||||||
|
self.flux.ae.eval()
|
||||||
|
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
|
||||||
|
self._encode_image(image, self.args.num_augmentations)
|
||||||
|
self._encode_prompt(prompt)
|
||||||
|
|
||||||
|
def iterate(self, batch_size):
|
||||||
|
xs = mx.concatenate(self.latents)
|
||||||
|
t5 = mx.concatenate(self.t5_features)
|
||||||
|
clip = mx.concatenate(self.clip_features)
|
||||||
|
mx.eval(xs, t5, clip)
|
||||||
|
n_aug = self.args.num_augmentations
|
||||||
|
while True:
|
||||||
|
x_indices = mx.random.permutation(len(self.latents))
|
||||||
|
c_indices = x_indices // n_aug
|
||||||
|
for i in range(0, len(self.latents), batch_size):
|
||||||
|
x_i = x_indices[i: i + batch_size]
|
||||||
|
c_i = c_indices[i: i + batch_size]
|
||||||
|
yield xs[x_i], t5[c_i], clip[c_i]
|
Loading…
Reference in New Issue
Block a user