From 7a20389c06a87cd7c14bb2b0554f1ae14e774e69 Mon Sep 17 00:00:00 2001 From: madroid Date: Sun, 13 Oct 2024 01:57:23 +0800 Subject: [PATCH] FLUX: fix pre-commit lint --- flux/dreambooth.py | 12 +++--------- flux/flux/datasets.py | 7 ++++--- flux/flux/flux.py | 2 +- flux/flux/trainer.py | 4 ++-- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/flux/dreambooth.py b/flux/dreambooth.py index fa6d8c89..444f6a1e 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -9,11 +9,11 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np -from PIL import Image from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce +from PIL import Image -from flux import FluxPipeline, load_dataset, Trainer +from flux import FluxPipeline, Trainer, load_dataset def generate_progress_images(iteration, flux, args): @@ -186,7 +186,6 @@ if __name__ == "__main__": optimizer = optim.Adam(learning_rate=lr_schedule) state = [flux.flow.state, optimizer.state, mx.random.state] - @partial(mx.compile, inputs=state, outputs=state) def single_step(x, t5_feat, clip_feat, guidance): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -197,14 +196,12 @@ if __name__ == "__main__": return loss - @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): return nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) - @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -213,7 +210,6 @@ if __name__ == "__main__": grads = tree_map(lambda a, b: a + b, prev_grads, grads) return loss, grads - @partial(mx.compile, inputs=state, outputs=state) def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -229,7 +225,6 @@ if __name__ == "__main__": return loss - # We simply route to the appropriate step based on whether we have # gradients from a previous step and whether we should be performing an # update or simply computing and accumulating gradients in this step. @@ -252,7 +247,6 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - # print("Create the training dataset.", flush=True) dataset = load_dataset(flux, args) trainer = Trainer(flux, dataset, args) @@ -273,7 +267,7 @@ if __name__ == "__main__": if (i + 1) % 10 == 0: toc = time.time() - peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem = mx.metal.get_peak_memory() / 1024**3 print( f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py index de6fada9..5a845d82 100644 --- a/flux/flux/datasets.py +++ b/flux/flux/datasets.py @@ -13,8 +13,8 @@ class Dataset: def __getitem__(self, index: int): item = self._data[index] - image = item['image'] - prompt = item['prompt'] + image = item["image"] + prompt = item["prompt"] return image, prompt @@ -43,13 +43,14 @@ class HuggingFaceDataset(Dataset): def __init__(self, flux, args): from datasets import load_dataset + df = load_dataset(args.dataset)["train"] self._data = df.data super().__init__(flux, args, df) def __getitem__(self, index: int): item = self._data[index] - return item['image'], item['prompt'] + return item["image"], item["prompt"] def load_dataset(flux, args): diff --git a/flux/flux/flux.py b/flux/flux/flux.py index b29a3c55..c9f23f7d 100644 --- a/flux/flux/flux.py +++ b/flux/flux/flux.py @@ -185,7 +185,7 @@ class FluxPipeline: images = [] for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i: i + 1])) + images.append(self.decode(x_t[i : i + 1])) mx.eval(images[-1]) images = mx.concatenate(images, axis=0) mx.eval(images) diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py index ecfb4854..ed645941 100644 --- a/flux/flux/trainer.py +++ b/flux/flux/trainer.py @@ -93,6 +93,6 @@ class Trainer: x_indices = mx.random.permutation(len(self.latents)) c_indices = x_indices // n_aug for i in range(0, len(self.latents), batch_size): - x_i = x_indices[i: i + batch_size] - c_i = c_indices[i: i + batch_size] + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] yield xs[x_i], t5[c_i], clip[c_i]