mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
FLUX: support huggingface dataset
This commit is contained in:
parent
ca88343118
commit
b0de67ec03
@ -1,7 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@ -10,12 +9,11 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline, load_dataset
|
||||
from flux import FluxPipeline, load_dataset, Trainer
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
@ -69,7 +67,7 @@ def setup_arg_parser():
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="dev",
|
||||
default="schnell",
|
||||
choices=[
|
||||
"dev",
|
||||
"schnell",
|
||||
@ -188,6 +186,7 @@ if __name__ == "__main__":
|
||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def single_step(x, t5_feat, clip_feat, guidance):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
@ -198,12 +197,14 @@ if __name__ == "__main__":
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
@ -212,6 +213,7 @@ if __name__ == "__main__":
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
return loss, grads
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
@ -227,6 +229,7 @@ if __name__ == "__main__":
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# We simply route to the appropriate step based on whether we have
|
||||
# gradients from a previous step and whether we should be performing an
|
||||
# update or simply computing and accumulating gradients in this step.
|
||||
@ -249,10 +252,12 @@ if __name__ == "__main__":
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
print("Create the training dataset.", flush=True)
|
||||
|
||||
# print("Create the training dataset.", flush=True)
|
||||
dataset = load_dataset(flux, args)
|
||||
dataset.encode_images()
|
||||
dataset.encode_prompts()
|
||||
trainer = Trainer(flux, dataset, args)
|
||||
trainer.encode_dataset()
|
||||
|
||||
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
||||
|
||||
# An initial generation to compare
|
||||
@ -261,16 +266,16 @@ if __name__ == "__main__":
|
||||
grads = None
|
||||
losses = []
|
||||
tic = time.time()
|
||||
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
|
||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||
mx.eval(loss, grads, state)
|
||||
losses.append(loss.item())
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
toc = time.time()
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
|
||||
print(
|
||||
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"It/s: {10 / (toc - tic):.3f} "
|
||||
f"Peak mem: {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
|
@ -1,107 +1,68 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __init__(self, flux, args):
|
||||
def __init__(self, flux, args, data):
|
||||
self.args = args
|
||||
self.flux = flux
|
||||
|
||||
self._data = data
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
image = item['image']
|
||||
prompt = item['prompt']
|
||||
|
||||
return image, prompt
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class LocalDataset(Dataset):
|
||||
|
||||
def __init__(self, flux, args, data_file):
|
||||
self.dataset_base = Path(args.dataset)
|
||||
data_file = self.dataset_base / "train.jsonl"
|
||||
if not data_file.exists():
|
||||
raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.")
|
||||
with open(data_file, "r") as fid:
|
||||
self.data = [json.loads(l) for l in fid]
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
super().__init__(flux, args, self._data)
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
image = Image.open(self.dataset_base / item["image"])
|
||||
return image, item["prompt"]
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * a) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * b,
|
||||
pan[1] * c,
|
||||
crop_size[0] + pan[0] * b,
|
||||
crop_size[1] + pan[1] * c,
|
||||
)
|
||||
)
|
||||
class HuggingFaceDataset(Dataset):
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
def __init__(self, flux, args):
|
||||
from datasets import load_dataset
|
||||
df = load_dataset(args.dataset)["train"]
|
||||
self._data = df.data
|
||||
super().__init__(flux, args, df)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def encode_images(self):
|
||||
"""Encode the images in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for sample in tqdm(self.data, desc="encode images"):
|
||||
input_img = Image.open(self.dataset_base / sample["image"])
|
||||
for i in range(self.args.num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def encode_prompts(self):
|
||||
"""Pre-encode the prompts so that we don't recompute them during
|
||||
training (doesn't allow finetuning the text encoders)."""
|
||||
for sample in tqdm(self.data, desc="encode prompts"):
|
||||
t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i: i + batch_size]
|
||||
c_i = c_indices[i: i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
return item['image'], item['prompt']
|
||||
|
||||
|
||||
def load_dataset(flux, args):
|
||||
return Dataset(flux, args)
|
||||
dataset_base = Path(args.dataset)
|
||||
data_file = dataset_base / "train.jsonl"
|
||||
|
||||
if data_file.exists():
|
||||
print(f"Load the local dataset {data_file} .", flush=True)
|
||||
# print(f"load local dataset: {data_file}")
|
||||
dataset = LocalDataset(flux, args, data_file)
|
||||
else:
|
||||
print(f"Load the Hugging Face dataset {args.dataset} .", flush=True)
|
||||
# print(f"load Hugging Face dataset: {args.dataset}")
|
||||
dataset = HuggingFaceDataset(flux, args)
|
||||
|
||||
return dataset
|
||||
|
98
flux/flux/trainer.py
Normal file
98
flux/flux/trainer.py
Normal file
@ -0,0 +1,98 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from tqdm import tqdm
|
||||
|
||||
from .datasets import Dataset
|
||||
from .flux import FluxPipeline
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
||||
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
|
||||
self.flux = flux
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * a) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * b,
|
||||
pan[1] * c,
|
||||
crop_size[0] + pan[0] * b,
|
||||
crop_size[1] + pan[1] * c,
|
||||
)
|
||||
)
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
|
||||
for i in range(num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def _encode_prompt(self, prompt):
|
||||
t5_tok, clip_tok = self.flux.tokenize([prompt])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def encode_dataset(self):
|
||||
"""Encode the images & prompt in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
|
||||
self._encode_image(image, self.args.num_augmentations)
|
||||
self._encode_prompt(prompt)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i: i + batch_size]
|
||||
c_i = c_indices[i: i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
Loading…
Reference in New Issue
Block a user