FLUX: fix pre-commit lint

This commit is contained in:
madroid 2024-10-13 01:57:23 +08:00
parent 082b27ffb2
commit 7a20389c06
4 changed files with 10 additions and 15 deletions

View File

@ -9,11 +9,11 @@ import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from PIL import Image
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from flux import FluxPipeline, load_dataset, Trainer
from flux import FluxPipeline, Trainer, load_dataset
def generate_progress_images(iteration, flux, args):
@ -186,7 +186,6 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -197,14 +196,12 @@ if __name__ == "__main__":
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -213,7 +210,6 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -229,7 +225,6 @@ if __name__ == "__main__":
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
@ -252,7 +247,6 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads
)
# print("Create the training dataset.", flush=True)
dataset = load_dataset(flux, args)
trainer = Trainer(flux, dataset, args)
@ -273,7 +267,7 @@ if __name__ == "__main__":
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
peak_mem = mx.metal.get_peak_memory() / 1024**3
print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "

View File

@ -13,8 +13,8 @@ class Dataset:
def __getitem__(self, index: int):
item = self._data[index]
image = item['image']
prompt = item['prompt']
image = item["image"]
prompt = item["prompt"]
return image, prompt
@ -43,13 +43,14 @@ class HuggingFaceDataset(Dataset):
def __init__(self, flux, args):
from datasets import load_dataset
df = load_dataset(args.dataset)["train"]
self._data = df.data
super().__init__(flux, args, df)
def __getitem__(self, index: int):
item = self._data[index]
return item['image'], item['prompt']
return item["image"], item["prompt"]
def load_dataset(flux, args):

View File

@ -185,7 +185,7 @@ class FluxPipeline:
images = []
for i in tqdm(range(len(x_t)), disable=not progress):
images.append(self.decode(x_t[i: i + 1]))
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)

View File

@ -93,6 +93,6 @@ class Trainer:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i: i + batch_size]
c_i = c_indices[i: i + batch_size]
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]