From e61849a0032b2819961956bc62fd62dd34bb6199 Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 12:35:49 +0800 Subject: [PATCH] FLUX: move cli to mlx_flux dir --- flux/{ => mlx_flux}/dreambooth.py | 21 ++++++++++++++------- flux/{ => mlx_flux}/txt2image.py | 23 +++++++++++++++-------- 2 files changed, 29 insertions(+), 15 deletions(-) rename flux/{ => mlx_flux}/dreambooth.py (98%) rename flux/{ => mlx_flux}/txt2image.py (92%) diff --git a/flux/dreambooth.py b/flux/mlx_flux/dreambooth.py similarity index 98% rename from flux/dreambooth.py rename to flux/mlx_flux/dreambooth.py index 9dcaffb3..8327af09 100644 --- a/flux/dreambooth.py +++ b/flux/mlx_flux/dreambooth.py @@ -1,19 +1,20 @@ # Copyright © 2024 Apple Inc. import argparse -import time -from functools import partial -from pathlib import Path - import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np +import time +from PIL import Image +from functools import partial from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce -from PIL import Image +from pathlib import Path -from mlx_flux import FluxPipeline, Trainer, load_dataset +from .datasets import load_dataset +from .flux import FluxPipeline +from .trainer import Trainer def generate_progress_images(iteration, flux, args): @@ -186,6 +187,7 @@ if __name__ == "__main__": optimizer = optim.Adam(learning_rate=lr_schedule) state = [flux.flow.state, optimizer.state, mx.random.state] + @partial(mx.compile, inputs=state, outputs=state) def single_step(x, t5_feat, clip_feat, guidance): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -196,12 +198,14 @@ if __name__ == "__main__": return loss + @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): return nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) + @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -210,6 +214,7 @@ if __name__ == "__main__": grads = tree_map(lambda a, b: a + b, prev_grads, grads) return loss, grads + @partial(mx.compile, inputs=state, outputs=state) def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -225,6 +230,7 @@ if __name__ == "__main__": return loss + # We simply route to the appropriate step based on whether we have # gradients from a previous step and whether we should be performing an # update or simply computing and accumulating gradients in this step. @@ -247,6 +253,7 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) + dataset = load_dataset(args.dataset) trainer = Trainer(flux, dataset, args) trainer.encode_dataset() @@ -266,7 +273,7 @@ if __name__ == "__main__": if (i + 1) % 10 == 0: toc = time.time() - peak_mem = mx.metal.get_peak_memory() / 1024**3 + peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 print( f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " diff --git a/flux/txt2image.py b/flux/mlx_flux/txt2image.py similarity index 92% rename from flux/txt2image.py rename to flux/mlx_flux/txt2image.py index fc209b17..358e8cad 100644 --- a/flux/txt2image.py +++ b/flux/mlx_flux/txt2image.py @@ -1,14 +1,13 @@ # Copyright © 2024 Apple Inc. import argparse - import mlx.core as mx import mlx.nn as nn import numpy as np from PIL import Image from tqdm import tqdm -from mlx_flux import FluxPipeline +from .flux import FluxPipeline def to_latent_size(image_size): @@ -39,7 +38,7 @@ def load_adapter(flux, adapter_file, fuse=False): flux.fuse_lora_layers() -if __name__ == "__main__": +def build_parser(): parser = argparse.ArgumentParser( description="Generate images from a textual prompt using stable diffusion" ) @@ -62,7 +61,11 @@ if __name__ == "__main__": parser.add_argument("--adapter") parser.add_argument("--fuse-adapter", action="store_true") parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false") - args = parser.parse_args() + return parser + + +def main(): + args = build_parser().parse_args() # Load the models flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding) @@ -93,7 +96,7 @@ if __name__ == "__main__": # First we get and eval the conditioning conditioning = next(latents) mx.eval(conditioning) - peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 + peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3 mx.metal.reset_peak_memory() # The following is not necessary but it may help in memory constrained @@ -108,15 +111,15 @@ if __name__ == "__main__": # The following is not necessary but it may help in memory constrained # systems by reusing the memory kept by the flow transformer. del flux.flow - peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 + peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3 mx.metal.reset_peak_memory() # Decode them into images decoded = [] for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): - decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) + decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size)) mx.eval(decoded[-1]) - peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 + peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3 peak_mem_overall = max( peak_mem_conditioning, peak_mem_generation, peak_mem_decoding ) @@ -148,3 +151,7 @@ if __name__ == "__main__": print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") + + +if __name__ == "__main__": + main()