FLUX: support huggingface dataset

This commit is contained in:
madroid
2024-10-13 01:38:58 +08:00
parent ca88343118
commit b0de67ec03
3 changed files with 162 additions and 98 deletions

View File

@@ -1,7 +1,6 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import time
from functools import partial
from pathlib import Path
@@ -10,12 +9,11 @@ import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from PIL import Image
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline, load_dataset
from flux import FluxPipeline, load_dataset, Trainer
def generate_progress_images(iteration, flux, args):
@@ -69,7 +67,7 @@ def setup_arg_parser():
parser.add_argument(
"--model",
default="dev",
default="schnell",
choices=[
"dev",
"schnell",
@@ -188,6 +186,7 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@@ -198,12 +197,14 @@ if __name__ == "__main__":
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@@ -212,6 +213,7 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@@ -227,6 +229,7 @@ if __name__ == "__main__":
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
@@ -249,10 +252,12 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads
)
print("Create the training dataset.", flush=True)
# print("Create the training dataset.", flush=True)
dataset = load_dataset(flux, args)
dataset.encode_images()
dataset.encode_prompts()
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
@@ -261,16 +266,16 @@ if __name__ == "__main__":
grads = None
losses = []
tic = time.time()
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state)
losses.append(loss.item())
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
print(
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB",
flush=True,