FLUX: move cli to mlx_flux dir

This commit is contained in:
madroid 2024-11-07 12:35:49 +08:00
parent 83c92c2a11
commit e61849a003
2 changed files with 29 additions and 15 deletions

View File

@ -1,19 +1,20 @@
# Copyright © 2024 Apple Inc.
import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import time
from PIL import Image
from functools import partial
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from pathlib import Path
from mlx_flux import FluxPipeline, Trainer, load_dataset
from .datasets import load_dataset
from .flux import FluxPipeline
from .trainer import Trainer
def generate_progress_images(iteration, flux, args):
@ -186,6 +187,7 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -196,12 +198,14 @@ if __name__ == "__main__":
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -210,6 +214,7 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -225,6 +230,7 @@ if __name__ == "__main__":
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
@ -247,6 +253,7 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads
)
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()
@ -266,7 +273,7 @@ if __name__ == "__main__":
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "

View File

@ -1,14 +1,13 @@
# Copyright © 2024 Apple Inc.
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from mlx_flux import FluxPipeline
from .flux import FluxPipeline
def to_latent_size(image_size):
@ -39,7 +38,7 @@ def load_adapter(flux, adapter_file, fuse=False):
flux.fuse_lora_layers()
if __name__ == "__main__":
def build_parser():
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
@ -62,7 +61,11 @@ if __name__ == "__main__":
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args()
return parser
def main():
args = build_parser().parse_args()
# Load the models
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
@ -93,7 +96,7 @@ if __name__ == "__main__":
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3
mx.metal.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
@ -108,15 +111,15 @@ if __name__ == "__main__":
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3
mx.metal.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
@ -148,3 +151,7 @@ if __name__ == "__main__":
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
if __name__ == "__main__":
main()