mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 19:06:37 +08:00
FLUX: fix pre-commit lint
This commit is contained in:
parent
082b27ffb2
commit
7a20389c06
@ -9,11 +9,11 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from flux import FluxPipeline, load_dataset, Trainer
|
from flux import FluxPipeline, Trainer, load_dataset
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args):
|
||||||
@ -186,7 +186,6 @@ if __name__ == "__main__":
|
|||||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def single_step(x, t5_feat, clip_feat, guidance):
|
def single_step(x, t5_feat, clip_feat, guidance):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -197,14 +196,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
||||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
x, t5_feat, clip_feat, guidance
|
x, t5_feat, clip_feat, guidance
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -213,7 +210,6 @@ if __name__ == "__main__":
|
|||||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||||
return loss, grads
|
return loss, grads
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -229,7 +225,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
# We simply route to the appropriate step based on whether we have
|
# We simply route to the appropriate step based on whether we have
|
||||||
# gradients from a previous step and whether we should be performing an
|
# gradients from a previous step and whether we should be performing an
|
||||||
# update or simply computing and accumulating gradients in this step.
|
# update or simply computing and accumulating gradients in this step.
|
||||||
@ -252,7 +247,6 @@ if __name__ == "__main__":
|
|||||||
x, t5_feat, clip_feat, guidance, prev_grads
|
x, t5_feat, clip_feat, guidance, prev_grads
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# print("Create the training dataset.", flush=True)
|
# print("Create the training dataset.", flush=True)
|
||||||
dataset = load_dataset(flux, args)
|
dataset = load_dataset(flux, args)
|
||||||
trainer = Trainer(flux, dataset, args)
|
trainer = Trainer(flux, dataset, args)
|
||||||
@ -273,7 +267,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
|
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||||
print(
|
print(
|
||||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||||
f"It/s: {10 / (toc - tic):.3f} "
|
f"It/s: {10 / (toc - tic):.3f} "
|
||||||
|
@ -13,8 +13,8 @@ class Dataset:
|
|||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
item = self._data[index]
|
item = self._data[index]
|
||||||
image = item['image']
|
image = item["image"]
|
||||||
prompt = item['prompt']
|
prompt = item["prompt"]
|
||||||
|
|
||||||
return image, prompt
|
return image, prompt
|
||||||
|
|
||||||
@ -43,13 +43,14 @@ class HuggingFaceDataset(Dataset):
|
|||||||
|
|
||||||
def __init__(self, flux, args):
|
def __init__(self, flux, args):
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
df = load_dataset(args.dataset)["train"]
|
df = load_dataset(args.dataset)["train"]
|
||||||
self._data = df.data
|
self._data = df.data
|
||||||
super().__init__(flux, args, df)
|
super().__init__(flux, args, df)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
item = self._data[index]
|
item = self._data[index]
|
||||||
return item['image'], item['prompt']
|
return item["image"], item["prompt"]
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(flux, args):
|
def load_dataset(flux, args):
|
||||||
|
@ -185,7 +185,7 @@ class FluxPipeline:
|
|||||||
|
|
||||||
images = []
|
images = []
|
||||||
for i in tqdm(range(len(x_t)), disable=not progress):
|
for i in tqdm(range(len(x_t)), disable=not progress):
|
||||||
images.append(self.decode(x_t[i: i + 1]))
|
images.append(self.decode(x_t[i : i + 1]))
|
||||||
mx.eval(images[-1])
|
mx.eval(images[-1])
|
||||||
images = mx.concatenate(images, axis=0)
|
images = mx.concatenate(images, axis=0)
|
||||||
mx.eval(images)
|
mx.eval(images)
|
||||||
|
@ -93,6 +93,6 @@ class Trainer:
|
|||||||
x_indices = mx.random.permutation(len(self.latents))
|
x_indices = mx.random.permutation(len(self.latents))
|
||||||
c_indices = x_indices // n_aug
|
c_indices = x_indices // n_aug
|
||||||
for i in range(0, len(self.latents), batch_size):
|
for i in range(0, len(self.latents), batch_size):
|
||||||
x_i = x_indices[i: i + batch_size]
|
x_i = x_indices[i : i + batch_size]
|
||||||
c_i = c_indices[i: i + batch_size]
|
c_i = c_indices[i : i + batch_size]
|
||||||
yield xs[x_i], t5[c_i], clip[c_i]
|
yield xs[x_i], t5[c_i], clip[c_i]
|
||||||
|
Loading…
Reference in New Issue
Block a user