# Copyright © 2024 Apple Inc.

import argparse
import os

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm

from mlx_flux import FluxPipeline


def to_latent_size(image_size):
    """Return the latent size for a requested ``(height, width)`` image size.

    Both dimensions are rounded up to the next multiple of 16. A warning is
    printed when that changes the requested size. The returned size is the
    rounded size divided by 8.
    """
    h, w = image_size
    h = ((h + 15) // 16) * 16
    w = ((w + 15) // 16) * 16

    # tuple() so a list input does not always compare unequal.
    if (h, w) != tuple(image_size):
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )

    return (h // 8, w // 8)


def quantization_predicate(name, m):
    """Return True for modules that ``nn.quantize`` should quantize.

    A module qualifies when it exposes ``to_quantized`` and the second
    dimension of its weight (the input dimension for MLX ``nn.Linear``) is a
    multiple of 512. ``name`` is unused but required by the predicate
    signature.
    """
    # NOTE(review): the 512 multiple presumably keeps layers compatible with
    # the quantization group size; confirm before changing it.
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


def load_adapter(flux, adapter_file, fuse=False):
    """Load LoRA adapter weights from ``adapter_file`` into ``flux.flow``.

    The file's metadata must provide ``lora_rank`` and ``lora_blocks``. These
    are used to convert the linear layers to LoRA layers before the weights
    are loaded non-strictly. When ``fuse`` is true, the LoRA layers are fused
    afterwards.
    """
    # NOTE(review): metadata is required here, so the adapter is presumably a
    # safetensors file; a missing key raises KeyError.
    weights, lora_config = mx.load(adapter_file, return_metadata=True)
    rank = int(lora_config["lora_rank"])
    num_blocks = int(lora_config["lora_blocks"])
    flux.linear_to_lora_layers(rank, num_blocks)
    # strict=False: the adapter file only contains a subset of the weights.
    flux.flow.load_weights(list(weights.items()), strict=False)
    if fuse:
        flux.fuse_lora_layers()


def _parse_image_size(value):
    """Parse an ``HxW`` string such as ``512x768`` into an ``(h, w)`` tuple.

    Raises ``argparse.ArgumentTypeError`` for malformed or non-positive sizes
    so argparse reports a clean usage error instead of crashing later.
    """
    parts = value.lower().split("x")
    message = f"expected HEIGHTxWIDTH with positive integers (e.g. 512x512), got {value!r}"
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(message)
    try:
        h, w = (int(p) for p in parts)
    except ValueError:
        raise argparse.ArgumentTypeError(message) from None
    if h <= 0 or w <= 0:
        raise argparse.ArgumentTypeError(message)
    return (h, w)


def _numbered_path(output, index):
    """Insert ``index`` before the extension of ``output``.

    ``out.png`` -> ``out.0.png``; ``out`` -> ``out.0``. Dots in directory
    names are left untouched.
    """
    root, ext = os.path.splitext(output)
    return f"{root}.{index}{ext}"


def _build_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--n-images", type=int, default=4)
    parser.add_argument("--image-size", type=_parse_image_size, default=(512, 512))
    parser.add_argument("--steps", type=int)
    parser.add_argument("--guidance", type=float, default=4.0)
    parser.add_argument("--n-rows", type=int, default=1)
    parser.add_argument("--decoding-batch-size", type=int, default=1)
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--preload-models", action="store_true")
    parser.add_argument("--output", default="out.png")
    parser.add_argument("--save-raw", action="store_true")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument("--adapter")
    parser.add_argument("--fuse-adapter", action="store_true")
    parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
    return parser


def _validate_args(parser, args):
    """Fill in defaults and reject invalid arguments before any model work.

    Calls ``parser.error`` (which exits) on invalid values, so problems are
    reported before the expensive generation instead of after it.
    """
    if args.steps is None:
        args.steps = 50 if args.model == "dev" else 2
    elif args.steps < 1:
        parser.error("--steps must be a positive integer")
    if args.n_images < 1:
        parser.error("--n-images must be a positive integer")
    if args.decoding_batch_size < 1:
        parser.error("--decoding-batch-size must be a positive integer")
    # The grid layout is only used when not saving raw images.
    if not args.save_raw and (args.n_rows < 1 or args.n_images % args.n_rows != 0):
        parser.error("--n-rows must be a positive divisor of --n-images")


def main():
    """Parse CLI arguments, generate images from the prompt and save them."""
    parser = _build_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    # Load the models
    flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)

    if args.adapter:
        load_adapter(flux, args.adapter, fuse=args.fuse_adapter)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    if args.preload_models:
        flux.ensure_models_are_loaded()

    # Make the generator
    latent_size = to_latent_size(args.image_size)
    latents = flux.generate_latents(
        args.prompt,
        n_images=args.n_images,
        num_steps=args.steps,
        latent_size=latent_size,
        guidance=args.guidance,
        seed=args.seed,
    )

    # The first item yielded by the generator is the conditioning; evaluate
    # it on its own so its peak memory can be reported separately.
    conditioning = next(latents)
    mx.eval(conditioning)
    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
    mx.metal.reset_peak_memory()

    # Not necessary, but it may help memory-constrained systems by freeing
    # the memory held by the text encoders.
    del flux.t5
    del flux.clip

    # Actual denoising loop; x_t holds the final latents afterwards.
    x_t = None
    for x_t in tqdm(latents, total=args.steps):
        mx.eval(x_t)

    # Not necessary, but it may help memory-constrained systems by freeing
    # the memory held by the flow transformer.
    del flux.flow
    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
    mx.metal.reset_peak_memory()

    # Decode the latents into images, in batches to bound peak memory.
    decoded = []
    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
        decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
        mx.eval(decoded[-1])
    peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
    peak_mem_overall = max(peak_mem_conditioning, peak_mem_generation, peak_mem_decoding)

    x = mx.concatenate(decoded, axis=0)
    if args.save_raw:
        # One file per image: out.png -> out.0.png, out.1.png, ...
        x = (x * 255).astype(mx.uint8)
        for i in range(len(x)):
            im = Image.fromarray(np.array(x[i]))
            im.save(_numbered_path(args.output, i))
    else:
        # Arrange them on an n_rows x n_cols grid with a 4px border per image.
        x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
        B, H, W, C = x.shape
        n_cols = B // args.n_rows  # exact: validated in _validate_args
        x = x.reshape(args.n_rows, n_cols, H, W, C).transpose(0, 2, 1, 3, 4)
        x = x.reshape(args.n_rows * H, n_cols * W, C)
        x = (x * 255).astype(mx.uint8)

        # Save them to disc
        im = Image.fromarray(np.array(x))
        im.save(args.output)

    # Report the peak memory used during generation
    if args.verbose:
        print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
        print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
        print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")


if __name__ == "__main__":
    main()