diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 13102cc8..fa6d8c89 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -1,7 +1,6 @@ # Copyright © 2024 Apple Inc. import argparse -import json import time from functools import partial from pathlib import Path @@ -10,12 +9,11 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np +from PIL import Image from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce -from PIL import Image -from tqdm import tqdm -from flux import FluxPipeline, load_dataset +from flux import FluxPipeline, load_dataset, Trainer def generate_progress_images(iteration, flux, args): @@ -69,7 +67,7 @@ def setup_arg_parser(): parser.add_argument( "--model", - default="dev", + default="schnell", choices=[ "dev", "schnell", @@ -188,6 +186,7 @@ if __name__ == "__main__": optimizer = optim.Adam(learning_rate=lr_schedule) state = [flux.flow.state, optimizer.state, mx.random.state] + @partial(mx.compile, inputs=state, outputs=state) def single_step(x, t5_feat, clip_feat, guidance): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -198,12 +197,14 @@ if __name__ == "__main__": return loss + @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): return nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) + @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -212,6 +213,7 @@ if __name__ == "__main__": grads = tree_map(lambda a, b: a + b, prev_grads, grads) return loss, grads + @partial(mx.compile, inputs=state, outputs=state) def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -227,6 +229,7 @@ if __name__ == "__main__": return loss + # We simply route to the appropriate step based on whether we have # gradients from a previous step and whether we should be performing an # update or simply computing and accumulating gradients in this step. @@ -249,10 +252,12 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - print("Create the training dataset.", flush=True) + + # print("Create the training dataset.", flush=True) dataset = load_dataset(flux, args) - dataset.encode_images() - dataset.encode_prompts() + trainer = Trainer(flux, dataset, args) + trainer.encode_dataset() + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) # An initial generation to compare @@ -261,16 +266,16 @@ if __name__ == "__main__": grads = None losses = [] tic = time.time() - for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): + for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) mx.eval(loss, grads, state) losses.append(loss.item()) if (i + 1) % 10 == 0: toc = time.time() - peak_mem = mx.metal.get_peak_memory() / 1024**3 + peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 print( - f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " f"Peak mem: {peak_mem:.3f} GB", flush=True, diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py index 6b782ca7..de6fada9 100644 --- a/flux/flux/datasets.py +++ b/flux/flux/datasets.py @@ -1,107 +1,68 @@ import json from pathlib import Path -import mlx.core as mx -import numpy as np from PIL import Image -from tqdm import tqdm class Dataset: - def __init__(self, flux, args): + def __init__(self, flux, args, data): self.args = args self.flux = flux + + self._data = data + + def __getitem__(self, index: int): + item = self._data[index] + image = item['image'] + prompt = item['prompt'] + + return image, prompt + + def __len__(self): + if self._data is None: + return 0 + return len(self._data) + + +class LocalDataset(Dataset): + + def __init__(self, flux, args, data_file): self.dataset_base = Path(args.dataset) - data_file = self.dataset_base / "train.jsonl" - if not data_file.exists(): - raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.") with open(data_file, "r") as fid: - self.data = [json.loads(l) for l in fid] + self._data = [json.loads(l) for l in fid] - self.latents = [] - self.t5_features = [] - self.clip_features = [] + super().__init__(flux, args, self._data) - def _random_crop_resize(self, img): - resolution = self.args.resolution - width, height = img.size + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item["prompt"] - a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() - # Random crop the input image between 0.8 to 1.0 of its original dimensions - crop_size = ( - max((0.8 + 0.2 * a) * width, resolution[0]), - max((0.8 + 0.2 * a) * height, resolution[1]), - ) - pan = (width - crop_size[0], height - crop_size[1]) - img = img.crop( - ( - pan[0] * b, - pan[1] * c, - crop_size[0] + pan[0] * b, - crop_size[1] + pan[1] * c, - ) - ) +class HuggingFaceDataset(Dataset): - # Fit the largest rectangle with the ratio of resolution in the image - # rectangle. - width, height = crop_size - ratio = resolution[0] / resolution[1] - r1 = (height * ratio, height) - r2 = (width, width / ratio) - r = r1 if r1[0] <= width else r2 - img = img.crop( - ( - (width - r[0]) / 2, - (height - r[1]) / 2, - (width + r[0]) / 2, - (height + r[1]) / 2, - ) - ) + def __init__(self, flux, args): + from datasets import load_dataset + df = load_dataset(args.dataset)["train"] + self._data = df.data + super().__init__(flux, args, df) - # Finally resize the image to resolution - img = img.resize(resolution, Image.LANCZOS) - - return mx.array(np.array(img)) - - def encode_images(self): - """Encode the images in the latent space to prepare for training.""" - self.flux.ae.eval() - for sample in tqdm(self.data, desc="encode images"): - input_img = Image.open(self.dataset_base / sample["image"]) - for i in range(self.args.num_augmentations): - img = self._random_crop_resize(input_img) - img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 - x_0 = self.flux.ae.encode(img[None]) - x_0 = x_0.astype(self.flux.dtype) - mx.eval(x_0) - self.latents.append(x_0) - - def encode_prompts(self): - """Pre-encode the prompts so that we don't recompute them during - training (doesn't allow finetuning the text encoders).""" - for sample in tqdm(self.data, desc="encode prompts"): - t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]]) - t5_feat = self.flux.t5(t5_tok) - clip_feat = self.flux.clip(clip_tok).pooled_output - mx.eval(t5_feat, clip_feat) - self.t5_features.append(t5_feat) - self.clip_features.append(clip_feat) - - def iterate(self, batch_size): - xs = mx.concatenate(self.latents) - t5 = mx.concatenate(self.t5_features) - clip = mx.concatenate(self.clip_features) - mx.eval(xs, t5, clip) - n_aug = self.args.num_augmentations - while True: - x_indices = mx.random.permutation(len(self.latents)) - c_indices = x_indices // n_aug - for i in range(0, len(self.latents), batch_size): - x_i = x_indices[i: i + batch_size] - c_i = c_indices[i: i + batch_size] - yield xs[x_i], t5[c_i], clip[c_i] + def __getitem__(self, index: int): + item = self._data[index] + return item['image'], item['prompt'] def load_dataset(flux, args): - return Dataset(flux, args) + dataset_base = Path(args.dataset) + data_file = dataset_base / "train.jsonl" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + # print(f"load local dataset: {data_file}") + dataset = LocalDataset(flux, args, data_file) + else: + print(f"Load the Hugging Face dataset {args.dataset} .", flush=True) + # print(f"load Hugging Face dataset: {args.dataset}") + dataset = HuggingFaceDataset(flux, args) + + return dataset diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py new file mode 100644 index 00000000..ecfb4854 --- /dev/null +++ b/flux/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * a) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * b, + pan[1] * c, + crop_size[0] + pan[0] * b, + crop_size[1] + pan[1] * c, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i: i + batch_size] + c_i = c_indices[i: i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i]