FLUX: fix pre-commit lint

This commit is contained in:
madroid 2024-10-13 01:57:23 +08:00
parent 082b27ffb2
commit 7a20389c06
4 changed files with 10 additions and 15 deletions

View File

@ -9,11 +9,11 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np import numpy as np
from PIL import Image
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from flux import FluxPipeline, load_dataset, Trainer from flux import FluxPipeline, Trainer, load_dataset
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
@ -186,7 +186,6 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule) optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state] state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance): def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -197,14 +196,12 @@ if __name__ == "__main__":
return loss return loss
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)( return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance x, t5_feat, clip_feat, guidance
) )
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -213,7 +210,6 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads) grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads return loss, grads
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -229,7 +225,6 @@ if __name__ == "__main__":
return loss return loss
# We simply route to the appropriate step based on whether we have # We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an # gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step. # update or simply computing and accumulating gradients in this step.
@ -252,7 +247,6 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads x, t5_feat, clip_feat, guidance, prev_grads
) )
# print("Create the training dataset.", flush=True) # print("Create the training dataset.", flush=True)
dataset = load_dataset(flux, args) dataset = load_dataset(flux, args)
trainer = Trainer(flux, dataset, args) trainer = Trainer(flux, dataset, args)
@ -273,7 +267,7 @@ if __name__ == "__main__":
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
toc = time.time() toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 peak_mem = mx.metal.get_peak_memory() / 1024**3
print( print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} " f"It/s: {10 / (toc - tic):.3f} "

View File

@ -13,8 +13,8 @@ class Dataset:
def __getitem__(self, index: int): def __getitem__(self, index: int):
item = self._data[index] item = self._data[index]
image = item['image'] image = item["image"]
prompt = item['prompt'] prompt = item["prompt"]
return image, prompt return image, prompt
@ -43,13 +43,14 @@ class HuggingFaceDataset(Dataset):
def __init__(self, flux, args): def __init__(self, flux, args):
from datasets import load_dataset from datasets import load_dataset
df = load_dataset(args.dataset)["train"] df = load_dataset(args.dataset)["train"]
self._data = df.data self._data = df.data
super().__init__(flux, args, df) super().__init__(flux, args, df)
def __getitem__(self, index: int): def __getitem__(self, index: int):
item = self._data[index] item = self._data[index]
return item['image'], item['prompt'] return item["image"], item["prompt"]
def load_dataset(flux, args): def load_dataset(flux, args):

View File

@ -185,7 +185,7 @@ class FluxPipeline:
images = [] images = []
for i in tqdm(range(len(x_t)), disable=not progress): for i in tqdm(range(len(x_t)), disable=not progress):
images.append(self.decode(x_t[i: i + 1])) images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1]) mx.eval(images[-1])
images = mx.concatenate(images, axis=0) images = mx.concatenate(images, axis=0)
mx.eval(images) mx.eval(images)

View File

@ -93,6 +93,6 @@ class Trainer:
x_indices = mx.random.permutation(len(self.latents)) x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size): for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i: i + batch_size] x_i = x_indices[i : i + batch_size]
c_i = c_indices[i: i + batch_size] c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i] yield xs[x_i], t5[c_i], clip[c_i]