# Copyright © 2024 Apple Inc.
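
"""Generate images from a text prompt with FLUX on MLX.

Supports optional LoRA adapters, quantization, and distributed generation
(sharding the flow transformer or splitting the batch across ranks).
"""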

import argparse

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm

from flux import FluxPipeline


def to_latent_size(image_size):
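    """Round an (height, width) image size up to a multiple of 16 pixels and
    return the corresponding latent size (the autoencoder downsamples by 8)."""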
    h, w = image_size
    h = ((h + 15) // 16) * 16
    w = ((w + 15) // 16) * 16

    if (h, w) != image_size:
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )

    return (h // 8, w // 8)


def quantization_predicate(name, m):
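    """Quantize only modules that support it and whose input dimension is a
    multiple of 512, so the weights divide evenly into quantization groups."""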
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


def load_adapter(flux, adapter_file, fuse=False):
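    """Load LoRA weights from `adapter_file`, wrap the matching linear layers,
    and optionally fuse the adapter back into the base weights."""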
    weights, lora_config = mx.load(adapter_file, return_metadata=True)
    rank = int(lora_config["lora_rank"])
    num_blocks = int(lora_config["lora_blocks"])
    flux.linear_to_lora_layers(rank, num_blocks)
    flux.flow.load_weights(list(weights.items()), strict=False)
    if fuse:
        flux.fuse_lora_layers()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--n-images", type=int, default=4)
    parser.add_argument(
        "--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
    )
    parser.add_argument("--steps", type=int)
    parser.add_argument("--guidance", type=float, default=4.0)
    parser.add_argument("--n-rows", type=int, default=1)
    parser.add_argument("--decoding-batch-size", type=int, default=1)
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--preload-models", action="store_true")
    parser.add_argument("--output", default="out.png")
    parser.add_argument("--save-raw", action="store_true")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument("--adapter")
    parser.add_argument("--fuse-adapter", action="store_true")
    parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
    parser.add_argument("--force-shard", action="store_true")
    args = parser.parse_args()
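
    # Example invocation (the script filename here is assumed):
    #   python txt2image.py "a photo of an astronaut riding a horse on mars" \
    #       --model schnell --n-images 4 --image-size 512x512 -q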

    # Load the models
    flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
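    # Default step counts: dev is typically run with 50 steps while schnell
    # is distilled for very few steps (2 here).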
    args.steps = args.steps or (50 if args.model == "dev" else 2)

    if args.adapter:
        load_adapter(flux, args.adapter, fuse=args.fuse_adapter)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    # Figure out what kind of distributed generation we should do
    group = mx.distributed.init()
    n_images = args.n_images
    should_gather = False
    if group.size() > 1:
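        # Either shard the flow transformer across ranks (every rank computes
        # the same images) or split the batch so each rank generates its own
        # slice of the images and the results are gathered at the end.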
        if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
            flux.flow.shard(group)
        else:
            n_images //= group.size()
            should_gather = True

    # If we are sharding, all ranks should use the same seed; if we are doing
    # data parallel generation, each rank should use a different seed.
    if args.seed is None:
        args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
    if should_gather:
        args.seed = args.seed + group.rank()

    if args.preload_models:
        flux.ensure_models_are_loaded()

    # Make the generator
    latent_size = to_latent_size(args.image_size)
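    # generate_latents returns a generator: it first yields the text
    # conditioning and then one partially denoised latent per step.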
    latents = flux.generate_latents(
        args.prompt,
        n_images=n_images,
        num_steps=args.steps,
        latent_size=latent_size,
        guidance=args.guidance,
        seed=args.seed,
    )

    # First we get and eval the conditioning
    conditioning = next(latents)
    mx.eval(conditioning)
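    # MLX is lazy, so forcing evaluation above lets us record the peak memory
    # used by the text encoders alone before resetting the counter.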
    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
    mx.reset_peak_memory()

    # The following is not necessary but it may help in memory constrained
    # systems by reusing the memory kept by the text encoders.
    del flux.t5
    del flux.clip

    # Actual denoising loop
    for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
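        # Evaluate each step eagerly so only one step's graph is kept alive
        # at a time (the progress bar is shown on rank 0 only).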
        mx.eval(x_t)

    # The following is not necessary but it may help in memory constrained
    # systems by reusing the memory kept by the flow transformer.
    del flux.flow
    peak_mem_generation = mx.get_peak_memory() / 1024**3
    mx.reset_peak_memory()

    # Decode them into images
    decoded = []
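    # After the loop above, x_t holds the fully denoised latents; decode them
    # in small batches to keep the decoder's memory use bounded.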
    for i in tqdm(range(0, n_images, args.decoding_batch_size)):
        decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
        mx.eval(decoded[-1])
    peak_mem_decoding = mx.get_peak_memory() / 1024**3
    peak_mem_overall = max(
        peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
    )

    # Gather them if each node has different images
    decoded = mx.concatenate(decoded, axis=0)
    if should_gather:
        decoded = mx.distributed.all_gather(decoded)
    mx.eval(decoded)

    if args.save_raw:
        *name, suffix = args.output.split(".")
        name = ".".join(name)
        x = decoded
        x = (x * 255).astype(mx.uint8)
        for i in range(len(x)):
            im = Image.fromarray(np.array(x[i]))
            im.save(".".join([name, str(i), suffix]))
    else:
        # Arrange them on a grid
        x = decoded
        x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
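        # Reshape the (B, H, W, C) batch into an n_rows x (B / n_rows) grid
        # of images; this assumes the number of images is divisible by n_rows.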
        B, H, W, C = x.shape
        x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
        x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
        x = (x * 255).astype(mx.uint8)

        # Save them to disk
        im = Image.fromarray(np.array(x))
        im.save(args.output)

    # Report the peak memory used during generation
    if args.verbose and group.rank() == 0:
        print(f"Peak memory used for the text encoding: {peak_mem_conditioning:.3f}GB")
        print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
        print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")