mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00

FLUX: Optimize dataset loading logic (#1038)

This commit is contained in:
  parent 3d62b058a4
  commit f491d473a3
flux/README.md
@@ -21,8 +21,9 @@ The dependencies are minimal, namely:

 - `huggingface-hub` to download the checkpoints.
 - `regex` for the tokenization
-- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
+- `tqdm`, `PIL`, and `numpy` for the scripts
 - `sentencepiece` for the T5 tokenizer
+- `datasets` for using an HF dataset directly

 You can install all of the above with the `requirements.txt` as follows:

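(The install step the README refers to is the usual pip invocation, shown here as a sketch; run it from the example's directory:)

```shell
pip install -r requirements.txt
```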
@@ -118,17 +119,12 @@ Finetuning

 The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
 but ymmv) on a provided image dataset. The dataset folder must have an
-`index.json` file with the following format:
+`train.jsonl` file with the following format:

-```json
-{
-    "data": [
-        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
-        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
-        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
-        ...
-    ]
-}
-```
+```jsonl
+{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
+{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
+...
+```

 The training script by default trains for 600 iterations with a batch size of
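For reference, a `train.jsonl` in this format is plain JSON Lines, so it can be produced with a few lines of Python (a minimal sketch; the file names, prompt, and dataset path are placeholders):

```python
import json
from pathlib import Path

# Placeholder samples: one JSON object per line, image paths relative to the dataset folder.
samples = [{"image": f"{i:02d}.jpg", "prompt": "A photo of sks dog"} for i in range(5)]

with open(Path("path/to/dataset") / "train.jsonl", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")
```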
@@ -150,19 +146,15 @@ The training images are the following 5 images [^2]:

 ![dog6](static/dog6.png)

-We start by making the following `index.json` file and placing it in the same
+We start by making the following `train.jsonl` file and placing it in the same
 folder as the images.

-```json
-{
-    "data": [
-        {"image": "00.jpg", "text": "A photo of sks dog"},
-        {"image": "01.jpg", "text": "A photo of sks dog"},
-        {"image": "02.jpg", "text": "A photo of sks dog"},
-        {"image": "03.jpg", "text": "A photo of sks dog"},
-        {"image": "04.jpg", "text": "A photo of sks dog"}
-    ]
-}
-```
+```jsonl
+{"image": "00.jpg", "prompt": "A photo of sks dog"}
+{"image": "01.jpg", "prompt": "A photo of sks dog"}
+{"image": "02.jpg", "prompt": "A photo of sks dog"}
+{"image": "03.jpg", "prompt": "A photo of sks dog"}
+{"image": "04.jpg", "prompt": "A photo of sks dog"}
+```

 Subsequently we finetune FLUX using the following command:
@@ -175,6 +167,17 @@ python dreambooth.py \
   path/to/dreambooth/dataset/dog6
 ```

+Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
+
+```shell
+python dreambooth.py \
+  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
+  --lora-rank 4 --grad-accumulate 8 \
+  mlx-community/dreambooth-dog6
+```
+
 The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
 bit more than 1 hour.

flux/dreambooth.py
@@ -1,7 +1,6 @@
 # Copyright © 2024 Apple Inc.

 import argparse
-import json
 import time
 from functools import partial
 from pathlib import Path
@@ -13,105 +12,8 @@ import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map, tree_reduce
 from PIL import Image
-from tqdm import tqdm

-from flux import FluxPipeline
+from flux import FluxPipeline, Trainer, load_dataset


-class FinetuningDataset:
-    def __init__(self, flux, args):
-        self.args = args
-        self.flux = flux
-        self.dataset_base = Path(args.dataset)
-        dataset_index = self.dataset_base / "index.json"
-        if not dataset_index.exists():
-            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
-        with open(dataset_index, "r") as f:
-            self.index = json.load(f)
-
-        self.latents = []
-        self.t5_features = []
-        self.clip_features = []
-
-    def _random_crop_resize(self, img):
-        resolution = self.args.resolution
-        width, height = img.size
-
-        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
-
-        # Random crop the input image between 0.8 to 1.0 of its original dimensions
-        crop_size = (
-            max((0.8 + 0.2 * a) * width, resolution[0]),
-            max((0.8 + 0.2 * a) * height, resolution[1]),
-        )
-        pan = (width - crop_size[0], height - crop_size[1])
-        img = img.crop(
-            (
-                pan[0] * b,
-                pan[1] * c,
-                crop_size[0] + pan[0] * b,
-                crop_size[1] + pan[1] * c,
-            )
-        )
-
-        # Fit the largest rectangle with the ratio of resolution in the image
-        # rectangle.
-        width, height = crop_size
-        ratio = resolution[0] / resolution[1]
-        r1 = (height * ratio, height)
-        r2 = (width, width / ratio)
-        r = r1 if r1[0] <= width else r2
-        img = img.crop(
-            (
-                (width - r[0]) / 2,
-                (height - r[1]) / 2,
-                (width + r[0]) / 2,
-                (height + r[1]) / 2,
-            )
-        )
-
-        # Finally resize the image to resolution
-        img = img.resize(resolution, Image.LANCZOS)
-
-        return mx.array(np.array(img))
-
-    def encode_images(self):
-        """Encode the images in the latent space to prepare for training."""
-        self.flux.ae.eval()
-        for sample in tqdm(self.index["data"]):
-            input_img = Image.open(self.dataset_base / sample["image"])
-            for i in range(self.args.num_augmentations):
-                img = self._random_crop_resize(input_img)
-                img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
-                x_0 = self.flux.ae.encode(img[None])
-                x_0 = x_0.astype(self.flux.dtype)
-                mx.eval(x_0)
-                self.latents.append(x_0)
-
-    def encode_prompts(self):
-        """Pre-encode the prompts so that we don't recompute them during
-        training (doesn't allow finetuning the text encoders)."""
-        for sample in tqdm(self.index["data"]):
-            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
-            t5_feat = self.flux.t5(t5_tok)
-            clip_feat = self.flux.clip(clip_tok).pooled_output
-            mx.eval(t5_feat, clip_feat)
-            self.t5_features.append(t5_feat)
-            self.clip_features.append(clip_feat)
-
-    def iterate(self, batch_size):
-        xs = mx.concatenate(self.latents)
-        t5 = mx.concatenate(self.t5_features)
-        clip = mx.concatenate(self.clip_features)
-        mx.eval(xs, t5, clip)
-        n_aug = self.args.num_augmentations
-        while True:
-            x_indices = mx.random.permutation(len(self.latents))
-            c_indices = x_indices // n_aug
-            for i in range(0, len(self.latents), batch_size):
-                x_i = x_indices[i : i + batch_size]
-                c_i = c_indices[i : i + batch_size]
-                yield xs[x_i], t5[c_i], clip[c_i]
-
-
 def generate_progress_images(iteration, flux, args):
@@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args):
     )


-if __name__ == "__main__":
+def setup_arg_parser():
+    """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(
         description="Finetune Flux to generate images with a specific subject"
     )
@@ -247,7 +150,11 @@ if __name__ == "__main__":
     )

     parser.add_argument("dataset")
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_arg_parser()
     args = parser.parse_args()

     # Load the model and set it up for LoRA training. We use the same random
@@ -267,7 +174,7 @@ if __name__ == "__main__":
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
     )
-    print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
+    print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)

     # Set up the optimizer and training steps. The steps are a bit verbose to
     # support gradient accumulation together with compilation.
@@ -340,10 +247,10 @@ if __name__ == "__main__":
                 x, t5_feat, clip_feat, guidance, prev_grads
             )

-    print("Create the training dataset.", flush=True)
-    dataset = FinetuningDataset(flux, args)
-    dataset.encode_images()
-    dataset.encode_prompts()
+    dataset = load_dataset(args.dataset)
+    trainer = Trainer(flux, dataset, args)
+    trainer.encode_dataset()
     guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)

     # An initial generation to compare
@@ -352,7 +259,7 @@ if __name__ == "__main__":
     grads = None
     losses = []
     tic = time.time()
-    for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
+    for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
        loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
        mx.eval(loss, grads, state)
        losses.append(loss.item())
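A note on the last argument to `step`: `(i + 1) % args.grad_accumulate == 0` is true once every `grad_accumulate` iterations, so the optimizer update fires on that cadence and every other iteration only accumulates gradients. A toy illustration (hypothetical value):

```python
grad_accumulate = 8  # hypothetical; matches the --grad-accumulate flag above
update_steps = [i for i in range(24) if (i + 1) % grad_accumulate == 0]
print(update_steps)  # [7, 15, 23]: one parameter update per 8 batches
```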
@@ -361,7 +268,7 @@ if __name__ == "__main__":
            toc = time.time()
            peak_mem = mx.metal.get_peak_memory() / 1024**3
            print(
-                f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
+                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
                f"It/s: {10 / (toc - tic):.3f} "
                f"Peak mem: {peak_mem:.3f} GB",
                flush=True,
flux/flux/__init__.py
@@ -1,16 +1,10 @@
 # Copyright © 2024 Apple Inc.

-import math
-import time
-from typing import Tuple
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx.utils import tree_unflatten
-from tqdm import tqdm
+from .datasets import Dataset, load_dataset
+from .flux import FluxPipeline

 from .lora import LoRALinear
 from .sampler import FluxSampler
+from .trainer import Trainer
 from .utils import (
     load_ae,
     load_clip,
@@ -19,230 +13,3 @@ from .utils import (
     load_t5,
     load_t5_tokenizer,
 )
-
-
-class FluxPipeline:
-    def __init__(self, name: str, t5_padding: bool = True):
-        self.dtype = mx.bfloat16
-        self.name = name
-        self.t5_padding = t5_padding
-
-        self.ae = load_ae(name)
-        self.flow = load_flow_model(name)
-        self.clip = load_clip(name)
-        self.clip_tokenizer = load_clip_tokenizer(name)
-        self.t5 = load_t5(name)
-        self.t5_tokenizer = load_t5_tokenizer(name)
-        self.sampler = FluxSampler(name)
-
-    def ensure_models_are_loaded(self):
-        mx.eval(
-            self.ae.parameters(),
-            self.flow.parameters(),
-            self.clip.parameters(),
-            self.t5.parameters(),
-        )
-
-    def reload_text_encoders(self):
-        self.t5 = load_t5(self.name)
-        self.clip = load_clip(self.name)
-
-    def tokenize(self, text):
-        t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
-        clip_tokens = self.clip_tokenizer.encode(text)
-        return t5_tokens, clip_tokens
-
-    def _prepare_latent_images(self, x):
-        b, h, w, c = x.shape
-
-        # Pack the latent image to 2x2 patches
-        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
-        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
-
-        # Create positions ids used to positionally encode each patch. Due to
-        # the way RoPE works, this results in an interesting positional
-        # encoding where parts of the feature are holding different positional
-        # information. Namely, the first part holds information independent of
-        # the spatial position (hence 0s), the 2nd part holds vertical spatial
-        # information and the last one horizontal.
-        i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
-        j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
-        x_ids = mx.stack([i, j, k], axis=-1)
-        x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
-
-        return x, x_ids
-
-    def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
-        # Prepare the text features
-        txt = self.t5(t5_tokens)
-        if len(txt) == 1 and n_images > 1:
-            txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
-        txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
-
-        # Prepare the clip text features
-        vec = self.clip(clip_tokens).pooled_output
-        if len(vec) == 1 and n_images > 1:
-            vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
-
-        return txt, txt_ids, vec
-
-    def _denoising_loop(
-        self,
-        x_t,
-        x_ids,
-        txt,
-        txt_ids,
-        vec,
-        num_steps: int = 35,
-        guidance: float = 4.0,
-        start: float = 1,
-        stop: float = 0,
-    ):
-        B = len(x_t)
-
-        def scalar(x):
-            return mx.full((B,), x, dtype=self.dtype)
-
-        guidance = scalar(guidance)
-        timesteps = self.sampler.timesteps(
-            num_steps,
-            x_t.shape[1],
-            start=start,
-            stop=stop,
-        )
-        for i in range(num_steps):
-            t = timesteps[i]
-            t_prev = timesteps[i + 1]
-
-            pred = self.flow(
-                img=x_t,
-                img_ids=x_ids,
-                txt=txt,
-                txt_ids=txt_ids,
-                y=vec,
-                timesteps=scalar(t),
-                guidance=guidance,
-            )
-            x_t = self.sampler.step(pred, x_t, t, t_prev)
-
-            yield x_t
-
-    def generate_latents(
-        self,
-        text: str,
-        n_images: int = 1,
-        num_steps: int = 35,
-        guidance: float = 4.0,
-        latent_size: Tuple[int, int] = (64, 64),
-        seed=None,
-    ):
-        # Set the PRNG state
-        if seed is not None:
-            mx.random.seed(seed)
-
-        # Create the latent variables
-        x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
-        x_T, x_ids = self._prepare_latent_images(x_T)
-
-        # Get the conditioning
-        t5_tokens, clip_tokens = self.tokenize(text)
-        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
-
-        # Yield the conditioning for controlled evaluation by the caller
-        yield (x_T, x_ids, txt, txt_ids, vec)
-
-        # Yield the latent sequences from the denoising loop
-        yield from self._denoising_loop(
-            x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
-        )
-
-    def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
-        h, w = latent_size
-        x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
-        x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
-        x = self.ae.decode(x)
-        return mx.clip(x + 1, 0, 2) * 0.5
-
-    def generate_images(
-        self,
-        text: str,
-        n_images: int = 1,
-        num_steps: int = 35,
-        guidance: float = 4.0,
-        latent_size: Tuple[int, int] = (64, 64),
-        seed=None,
-        reload_text_encoders: bool = True,
-        progress: bool = True,
-    ):
-        latents = self.generate_latents(
-            text, n_images, num_steps, guidance, latent_size, seed
-        )
-        mx.eval(next(latents))
-
-        if reload_text_encoders:
-            self.reload_text_encoders()
-
-        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
-            mx.eval(x_t)
-
-        images = []
-        for i in tqdm(range(len(x_t)), disable=not progress):
-            images.append(self.decode(x_t[i : i + 1]))
-            mx.eval(images[-1])
-        images = mx.concatenate(images, axis=0)
-        mx.eval(images)
-
-        return images
-
-    def training_loss(
-        self,
-        x_0: mx.array,
-        t5_features: mx.array,
-        clip_features: mx.array,
-        guidance: mx.array,
-    ):
-        # Get the text conditioning
-        txt = t5_features
-        txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
-        vec = clip_features
-
-        # Prepare the latent input
-        x_0, x_ids = self._prepare_latent_images(x_0)
-
-        # Forward process
-        t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
-        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
-        x_t = self.sampler.add_noise(x_0, t, noise=eps)
-        x_t = mx.stop_gradient(x_t)
-
-        # Do the denoising
-        pred = self.flow(
-            img=x_t,
-            img_ids=x_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
-            timesteps=t,
-            guidance=guidance,
-        )
-
-        return (pred + x_0 - eps).square().mean()
-
-    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
-        """Swap the linear layers in the transformer blocks with LoRA layers."""
-        all_blocks = self.flow.double_blocks + self.flow.single_blocks
-        all_blocks.reverse()
-        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
-        for i, block in zip(range(num_blocks), all_blocks):
-            loras = []
-            for name, module in block.named_modules():
-                if isinstance(module, nn.Linear):
-                    loras.append((name, LoRALinear.from_base(module, r=rank)))
-            block.update_modules(tree_unflatten(loras))
-
-    def fuse_lora_layers(self):
-        fused_layers = []
-        for name, module in self.flow.named_modules():
-            if isinstance(module, LoRALinear):
-                fused_layers.append((name, module.fuse()))
-        self.flow.update_modules(tree_unflatten(fused_layers))
flux/flux/datasets.py (new file, +75 lines)
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+
+from PIL import Image
+
+
+class Dataset:
+    def __getitem__(self, index: int):
+        raise NotImplementedError()
+
+    def __len__(self):
+        raise NotImplementedError()
+
+
+class LocalDataset(Dataset):
+    prompt_key = "prompt"
+
+    def __init__(self, dataset: str, data_file):
+        self.dataset_base = Path(dataset)
+        with open(data_file, "r") as fid:
+            self._data = [json.loads(l) for l in fid]
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index: int):
+        item = self._data[index]
+        image = Image.open(self.dataset_base / item["image"])
+        return image, item[self.prompt_key]
+
+
+class LegacyDataset(LocalDataset):
+    prompt_key = "text"
+
+    def __init__(self, dataset: str):
+        self.dataset_base = Path(dataset)
+        with open(self.dataset_base / "index.json") as f:
+            self._data = json.load(f)["data"]
+
+
+class HuggingFaceDataset(Dataset):
+    def __init__(self, dataset: str):
+        from datasets import load_dataset as hf_load_dataset
+
+        self._df = hf_load_dataset(dataset)["train"]
+
+    def __len__(self):
+        return len(self._df)
+
+    def __getitem__(self, index: int):
+        item = self._df[index]
+        return item["image"], item["prompt"]
+
+
+def load_dataset(dataset: str):
+    dataset_base = Path(dataset)
+    data_file = dataset_base / "train.jsonl"
+    legacy_file = dataset_base / "index.json"
+
+    if data_file.exists():
+        print(f"Load the local dataset {data_file} .", flush=True)
+        dataset = LocalDataset(dataset, data_file)
+    elif legacy_file.exists():
+        print(f"Load the local dataset {legacy_file} .")
+        print()
+        print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+        print(" See the README for details.")
+        print(flush=True)
+        dataset = LegacyDataset(dataset)
+    else:
+        print(f"Load the Hugging Face dataset {dataset} .", flush=True)
+        dataset = HuggingFaceDataset(dataset)
+
+    return dataset
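The dispatch above can be exercised directly (a minimal sketch; the dataset path is a placeholder, and `flux` is assumed importable as in the updated `__init__.py`):

```python
from flux import load_dataset

# A folder with train.jsonl -> LocalDataset; only index.json -> LegacyDataset
# (with a deprecation warning); anything else -> HuggingFaceDataset.
dataset = load_dataset("path/to/dreambooth/dataset/dog6")
print(len(dataset))
image, prompt = dataset[0]  # a PIL image and its prompt string
```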
flux/flux/flux.py (new file, +246 lines)
@@ -0,0 +1,246 @@
+# Copyright © 2024 Apple Inc.
+
+from typing import Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from tqdm import tqdm
+
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)
+
+
+class FluxPipeline:
+    def __init__(self, name: str, t5_padding: bool = True):
+        self.dtype = mx.bfloat16
+        self.name = name
+        self.t5_padding = t5_padding
+
+        self.ae = load_ae(name)
+        self.flow = load_flow_model(name)
+        self.clip = load_clip(name)
+        self.clip_tokenizer = load_clip_tokenizer(name)
+        self.t5 = load_t5(name)
+        self.t5_tokenizer = load_t5_tokenizer(name)
+        self.sampler = FluxSampler(name)
+
+    def ensure_models_are_loaded(self):
+        mx.eval(
+            self.ae.parameters(),
+            self.flow.parameters(),
+            self.clip.parameters(),
+            self.t5.parameters(),
+        )
+
+    def reload_text_encoders(self):
+        self.t5 = load_t5(self.name)
+        self.clip = load_clip(self.name)
+
+    def tokenize(self, text):
+        t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
+        clip_tokens = self.clip_tokenizer.encode(text)
+        return t5_tokens, clip_tokens
+
+    def _prepare_latent_images(self, x):
+        b, h, w, c = x.shape
+
+        # Pack the latent image to 2x2 patches
+        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
+        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
+
+        # Create positions ids used to positionally encode each patch. Due to
+        # the way RoPE works, this results in an interesting positional
+        # encoding where parts of the feature are holding different positional
+        # information. Namely, the first part holds information independent of
+        # the spatial position (hence 0s), the 2nd part holds vertical spatial
+        # information and the last one horizontal.
+        i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
+        j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
+        x_ids = mx.stack([i, j, k], axis=-1)
+        x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
+
+        return x, x_ids
+
+    def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
+        # Prepare the text features
+        txt = self.t5(t5_tokens)
+        if len(txt) == 1 and n_images > 1:
+            txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
+        txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
+
+        # Prepare the clip text features
+        vec = self.clip(clip_tokens).pooled_output
+        if len(vec) == 1 and n_images > 1:
+            vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
+
+        return txt, txt_ids, vec
+
+    def _denoising_loop(
+        self,
+        x_t,
+        x_ids,
+        txt,
+        txt_ids,
+        vec,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        start: float = 1,
+        stop: float = 0,
+    ):
+        B = len(x_t)
+
+        def scalar(x):
+            return mx.full((B,), x, dtype=self.dtype)
+
+        guidance = scalar(guidance)
+        timesteps = self.sampler.timesteps(
+            num_steps,
+            x_t.shape[1],
+            start=start,
+            stop=stop,
+        )
+        for i in range(num_steps):
+            t = timesteps[i]
+            t_prev = timesteps[i + 1]
+
+            pred = self.flow(
+                img=x_t,
+                img_ids=x_ids,
+                txt=txt,
+                txt_ids=txt_ids,
+                y=vec,
+                timesteps=scalar(t),
+                guidance=guidance,
+            )
+            x_t = self.sampler.step(pred, x_t, t, t_prev)
+
+            yield x_t
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        if seed is not None:
+            mx.random.seed(seed)
+
+        # Create the latent variables
+        x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
+        x_T, x_ids = self._prepare_latent_images(x_T)
+
+        # Get the conditioning
+        t5_tokens, clip_tokens = self.tokenize(text)
+        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
+
+        # Yield the conditioning for controlled evaluation by the caller
+        yield (x_T, x_ids, txt, txt_ids, vec)
+
+        # Yield the latent sequences from the denoising loop
+        yield from self._denoising_loop(
+            x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
+        )
+
+    def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
+        h, w = latent_size
+        x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
+        x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
+        x = self.ae.decode(x)
+        return mx.clip(x + 1, 0, 2) * 0.5
+
+    def generate_images(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+        reload_text_encoders: bool = True,
+        progress: bool = True,
+    ):
+        latents = self.generate_latents(
+            text, n_images, num_steps, guidance, latent_size, seed
+        )
+        mx.eval(next(latents))
+
+        if reload_text_encoders:
+            self.reload_text_encoders()
+
+        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+            mx.eval(x_t)
+
+        images = []
+        for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
+            images.append(self.decode(x_t[i : i + 1]))
+            mx.eval(images[-1])
+        images = mx.concatenate(images, axis=0)
+        mx.eval(images)
+
+        return images
+
+    def training_loss(
+        self,
+        x_0: mx.array,
+        t5_features: mx.array,
+        clip_features: mx.array,
+        guidance: mx.array,
+    ):
+        # Get the text conditioning
+        txt = t5_features
+        txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
+        vec = clip_features
+
+        # Prepare the latent input
+        x_0, x_ids = self._prepare_latent_images(x_0)
+
+        # Forward process
+        t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
+        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+        x_t = self.sampler.add_noise(x_0, t, noise=eps)
+        x_t = mx.stop_gradient(x_t)
+
+        # Do the denoising
+        pred = self.flow(
+            img=x_t,
+            img_ids=x_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t,
+            guidance=guidance,
+        )
+
+        return (pred + x_0 - eps).square().mean()
+
+    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+        """Swap the linear layers in the transformer blocks with LoRA layers."""
+        all_blocks = self.flow.double_blocks + self.flow.single_blocks
+        all_blocks.reverse()
+        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+        for i, block in zip(range(num_blocks), all_blocks):
+            loras = []
+            for name, module in block.named_modules():
+                if isinstance(module, nn.Linear):
+                    loras.append((name, LoRALinear.from_base(module, r=rank)))
+            block.update_modules(tree_unflatten(loras))
+
+    def fuse_lora_layers(self):
+        fused_layers = []
+        for name, module in self.flow.named_modules():
+            if isinstance(module, LoRALinear):
+                fused_layers.append((name, module.fuse()))
+        self.flow.update_modules(tree_unflatten(fused_layers))
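The pipeline moved into this module keeps its previous interface, e.g. (a sketch; the model name, prompt, and step count are illustrative):

```python
from flux import FluxPipeline

flux = FluxPipeline("flux-schnell")  # illustrative model name
images = flux.generate_images(
    "A photo of an sks dog",
    n_images=1,
    num_steps=2,
    latent_size=(64, 64),  # decodes to a 512x512 image
)
```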
flux/flux/trainer.py (new file, +98 lines)
@@ -0,0 +1,98 @@
+import mlx.core as mx
+import numpy as np
+from PIL import Image, ImageFile
+from tqdm import tqdm
+
+from .datasets import Dataset
+from .flux import FluxPipeline
+
+
+class Trainer:
+    def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
+        self.flux = flux
+        self.dataset = dataset
+        self.args = args
+        self.latents = []
+        self.t5_features = []
+        self.clip_features = []
+
+    def _random_crop_resize(self, img):
+        resolution = self.args.resolution
+        width, height = img.size
+
+        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+        # Random crop the input image between 0.8 to 1.0 of its original dimensions
+        crop_size = (
+            max((0.8 + 0.2 * a) * width, resolution[0]),
+            max((0.8 + 0.2 * b) * height, resolution[1]),
+        )
+        pan = (width - crop_size[0], height - crop_size[1])
+        img = img.crop(
+            (
+                pan[0] * c,
+                pan[1] * d,
+                crop_size[0] + pan[0] * c,
+                crop_size[1] + pan[1] * d,
+            )
+        )
+
+        # Fit the largest rectangle with the ratio of resolution in the image
+        # rectangle.
+        width, height = crop_size
+        ratio = resolution[0] / resolution[1]
+        r1 = (height * ratio, height)
+        r2 = (width, width / ratio)
+        r = r1 if r1[0] <= width else r2
+        img = img.crop(
+            (
+                (width - r[0]) / 2,
+                (height - r[1]) / 2,
+                (width + r[0]) / 2,
+                (height + r[1]) / 2,
+            )
+        )
+
+        # Finally resize the image to resolution
+        img = img.resize(resolution, Image.LANCZOS)
+
+        return mx.array(np.array(img))
+
+    def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
+        for i in range(num_augmentations):
+            img = self._random_crop_resize(input_img)
+            img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+            x_0 = self.flux.ae.encode(img[None])
+            x_0 = x_0.astype(self.flux.dtype)
+            mx.eval(x_0)
+            self.latents.append(x_0)
+
+    def _encode_prompt(self, prompt):
+        t5_tok, clip_tok = self.flux.tokenize([prompt])
+        t5_feat = self.flux.t5(t5_tok)
+        clip_feat = self.flux.clip(clip_tok).pooled_output
+        mx.eval(t5_feat, clip_feat)
+        self.t5_features.append(t5_feat)
+        self.clip_features.append(clip_feat)
+
+    def encode_dataset(self):
+        """Encode the images & prompt in the latent space to prepare for training."""
+        self.flux.ae.eval()
+        for image, prompt in tqdm(self.dataset, desc="encode dataset"):
+            self._encode_image(image, self.args.num_augmentations)
+            self._encode_prompt(prompt)
+
+    def iterate(self, batch_size):
+        xs = mx.concatenate(self.latents)
+        t5 = mx.concatenate(self.t5_features)
+        clip = mx.concatenate(self.clip_features)
+        mx.eval(xs, t5, clip)
+        n_aug = self.args.num_augmentations
+        while True:
+            x_indices = mx.random.permutation(len(self.latents))
+            c_indices = x_indices // n_aug
+            for i in range(0, len(self.latents), batch_size):
+                x_i = x_indices[i : i + batch_size]
+                c_i = c_indices[i : i + batch_size]
+                yield xs[x_i], t5[c_i], clip[c_i]
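A note on `iterate`: `encode_dataset` appends `num_augmentations` latents per image but only one set of text features per prompt, so integer-dividing a shuffled latent index by `n_aug` recovers the index of the matching prompt features. In miniature:

```python
n_aug = 2  # hypothetical: two augmentations per image
# latents are appended image-major: [img0_aug0, img0_aug1, img1_aug0, ...]
for k in range(6):
    print(f"latent {k} -> prompt {k // n_aug}")  # 0,1 -> 0; 2,3 -> 1; 4,5 -> 2
```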