mlx-examples/flux/generate_interactive.py

110 lines
3.3 KiB
Python
Raw Permalink Normal View History

2025-03-25 13:16:48 +08:00
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def print_zero(group, *args, **kwargs):
    """Print ``args`` only when ``group.rank()`` is 0.

    Args:
        group: Object exposing a ``rank()`` method (here the group returned
            by ``mx.distributed.init()``).
        *args: Positional arguments forwarded to ``print``.
        **kwargs: Keyword arguments forwarded to ``print``. ``flush``
            defaults to ``True`` unless the caller passes it explicitly.
    """
    # Every non-zero rank stays silent so a message appears only once.
    if group.rank() != 0:
        return
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)
def quantization_predicate(name, m):
    """Decide whether module ``m`` should be quantized.

    Intended for use as the ``class_predicate`` argument of ``nn.quantize``.

    Args:
        name: Module name supplied by the caller; not used in the decision.
        m: Candidate module.

    Returns:
        ``True`` when ``m`` has a ``to_quantized`` attribute and the second
        dimension of ``m.weight`` is a multiple of 512, otherwise ``False``.
    """
    if not hasattr(m, "to_quantized"):
        return False
    # presumably shape[1] is the layer's input dimension -- confirm
    second_dim = m.weight.shape[1]
    return second_dim % 512 == 0
def to_latent_size(image_size):
    """Round an image size up to a multiple of 16 px and return the latent size.

    Each dimension is rounded up to the next multiple of 16. If rounding
    changed either dimension, a warning with the adjusted size is printed.

    Args:
        image_size: Two-element sequence ``(height, width)``. Any sequence
            type (tuple, list, ...) is accepted.

    Returns:
        The tuple ``(height // 8, width // 8)`` computed from the rounded
        dimensions.
    """
    h, w = image_size
    requested = (h, w)
    h = ((h + 15) // 16) * 16
    w = ((w + 15) // 16) * 16
    # Compare against a tuple built from the unpacked values: comparing with
    # ``image_size`` directly is always unequal when the caller passes a list,
    # which would print the warning even for sizes that needed no change.
    if (h, w) != requested:
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )
    return (h // 8, w // 8)
def parse_size(command):
    """Parse an ``s HxW`` command into a ``(height, width)`` tuple of ints.

    Args:
        command: The full command line, starting with ``"s "``.

    Returns:
        ``(height, width)`` as positive integers.

    Raises:
        ValueError: If the text after ``"s "`` is not exactly two positive
            integers separated by ``x``.
    """
    parts = command[2:].lower().split("x")
    if len(parts) != 2:
        raise ValueError("expected HxW, e.g. 's 512x512'")
    try:
        height, width = (int(part) for part in parts)
    except ValueError:
        raise ValueError("height and width must be whole numbers") from None
    if height <= 0 or width <= 0:
        raise ValueError("height and width must be positive")
    return height, width


def parse_steps(command):
    """Parse an ``n S`` command into a positive integer step count.

    Args:
        command: The full command line, starting with ``"n "``.

    Returns:
        The number of steps, at least 1.

    Raises:
        ValueError: If the text after ``"n "`` is not an integer >= 1.
    """
    try:
        steps = int(command[2:])
    except ValueError:
        raise ValueError("expected a whole number, e.g. 'n 4'") from None
    if steps < 1:
        raise ValueError("the number of steps must be at least 1")
    return steps


# Interactive session: load the FLUX pipeline once, then repeatedly read a
# prompt (or a short command) from stdin and write the generated image to
# ``--output``.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--output", default="out.png")
    args = parser.parse_args()

    flux = FluxPipeline("flux-" + args.model, t5_padding=True)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    # Shard the flow model only when more than one process participates.
    group = mx.distributed.init()
    if group.size() > 1:
        flux.flow.shard(group)

    print_zero(group, "Loading models")
    flux.ensure_models_are_loaded()

    def print_help():
        """Print the list of interactive commands (on rank 0 only)."""
        print_zero(group, "The command list:")
        print_zero(group, "- 'q' to exit")
        print_zero(group, "- 's HxW' to change the size of the image")
        print_zero(group, "- 'n S' to change the number of steps")
        print_zero(group, "- 'h' to print this help")

    print_zero(group, "FLUX interactive session")
    print_help()

    seed = 0
    size = (512, 512)
    latent_size = to_latent_size(size)
    steps = 50 if args.model == "dev" else 4

    while True:
        try:
            prompt = input(">> " if group.rank() == 0 else "")
        except EOFError:
            # stdin was closed (Ctrl-D or an exhausted pipe): exit like 'q'
            # instead of ending the session with a traceback.
            break

        if prompt == "q":
            break
        if prompt == "h":
            print_help()
            continue
        if prompt.startswith("s "):
            # A malformed size must not kill the session; report and re-prompt
            # while keeping the previous size.
            try:
                size = parse_size(prompt)
            except ValueError as err:
                print_zero(group, "Invalid size:", err)
                continue
            print_zero(group, "Setting the size to", size)
            latent_size = to_latent_size(size)
            continue
        if prompt.startswith("n "):
            # Reject non-numeric and non-positive step counts: with zero
            # steps the denoising loop below would never bind ``xt``.
            try:
                steps = parse_steps(prompt)
            except ValueError as err:
                print_zero(group, "Invalid number of steps:", err)
                continue
            print_zero(group, "Setting the steps to", steps)
            continue

        # Anything else is treated as a text prompt. The seed is bumped for
        # every generation so each image uses a different seed.
        seed += 1
        latents = flux.generate_latents(
            prompt,
            n_images=1,
            num_steps=steps,
            latent_size=latent_size,
            guidance=4.0,
            seed=seed,
        )

        # The first item yielded is evaluated on its own before the progress
        # bar starts (presumably the prompt conditioning -- confirm against
        # FluxPipeline.generate_latents).
        print_zero(group, "Processing prompt")
        mx.eval(next(latents))

        print_zero(group, "Generating latents")
        for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
            mx.eval(xt)

        print_zero(group, "Generating image")
        # ``xt`` is the last item produced by the loop above.
        xt = flux.decode(xt, latent_size)
        # assumes decode returns values in [0, 1] -- TODO confirm
        xt = (xt * 255).astype(mx.uint8)
        mx.eval(xt)
        im = Image.fromarray(np.array(xt[0]))
        im.save(args.output)
        print_zero(group, "Saved at", args.output, end="\n\n")