FLUX: move cli to mlx_flux dir

This commit is contained in:
madroid 2024-11-07 12:35:49 +08:00
parent 83c92c2a11
commit e61849a003
2 changed files with 29 additions and 15 deletions

View File

@ -1,19 +1,20 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np import numpy as np
import time
from PIL import Image
from functools import partial
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from pathlib import Path
from mlx_flux import FluxPipeline, Trainer, load_dataset from .datasets import load_dataset
from .flux import FluxPipeline
from .trainer import Trainer
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
@ -186,6 +187,7 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule) optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state] state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance): def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -196,12 +198,14 @@ if __name__ == "__main__":
return loss return loss
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)( return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance x, t5_feat, clip_feat, guidance
) )
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -210,6 +214,7 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads) grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads return loss, grads
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -225,6 +230,7 @@ if __name__ == "__main__":
return loss return loss
# We simply route to the appropriate step based on whether we have # We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an # gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step. # update or simply computing and accumulating gradients in this step.
@ -247,6 +253,7 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads x, t5_feat, clip_feat, guidance, prev_grads
) )
dataset = load_dataset(args.dataset) dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args) trainer = Trainer(flux, dataset, args)
trainer.encode_dataset() trainer.encode_dataset()
@ -266,7 +273,7 @@ if __name__ == "__main__":
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
toc = time.time() toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3 peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
print( print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} " f"It/s: {10 / (toc - tic):.3f} "

View File

@ -1,14 +1,13 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from mlx_flux import FluxPipeline from .flux import FluxPipeline
def to_latent_size(image_size): def to_latent_size(image_size):
@ -39,7 +38,7 @@ def load_adapter(flux, adapter_file, fuse=False):
flux.fuse_lora_layers() flux.fuse_lora_layers()
if __name__ == "__main__": def build_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion" description="Generate images from a textual prompt using stable diffusion"
) )
@ -62,7 +61,11 @@ if __name__ == "__main__":
parser.add_argument("--adapter") parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true") parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false") parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args() return parser
def main():
args = build_parser().parse_args()
# Load the models # Load the models
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding) flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
@ -93,7 +96,7 @@ if __name__ == "__main__":
# First we get and eval the conditioning # First we get and eval the conditioning
conditioning = next(latents) conditioning = next(latents)
mx.eval(conditioning) mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3
mx.metal.reset_peak_memory() mx.metal.reset_peak_memory()
# The following is not necessary but it may help in memory constrained # The following is not necessary but it may help in memory constrained
@ -108,15 +111,15 @@ if __name__ == "__main__":
# The following is not necessary but it may help in memory constrained # The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer. # systems by reusing the memory kept by the flow transformer.
del flux.flow del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3
mx.metal.reset_peak_memory() mx.metal.reset_peak_memory()
# Decode them into images # Decode them into images
decoded = [] decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1]) mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3
peak_mem_overall = max( peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
) )
@ -148,3 +151,7 @@ if __name__ == "__main__":
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
if __name__ == "__main__":
main()