mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
138 lines
4.9 KiB
Python
138 lines
4.9 KiB
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import argparse
|
|
import math
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
from stable_diffusion import StableDiffusion, StableDiffusionXL
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: an input image plus a text prompt are turned into
    # `n_images` new images, arranged on an `n_rows` grid and saved to --output.
    parser = argparse.ArgumentParser(
        description="Generate images from an image and a textual prompt using stable diffusion"
    )
    parser.add_argument("image")
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
    parser.add_argument("--strength", type=float, default=0.9)
    parser.add_argument("--n_images", type=int, default=4)
    parser.add_argument("--steps", type=int)
    parser.add_argument("--cfg", type=float)
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--n_rows", type=int, default=1)
    parser.add_argument("--decoding_batch_size", type=int, default=1)
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--no-float16", dest="float16", action="store_false")
    parser.add_argument("--preload-models", action="store_true")
    parser.add_argument("--output", default="out.png")
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    # Validate numeric options before any model is loaded, so bad values fail
    # immediately instead of after the expensive denoising loop.
    if not 0.0 < args.strength <= 1.0:
        # strength == 0 would divide by zero in the step fix-up below. The
        # strength scales the step count (see `int(steps * strength)`), so it
        # is treated as a fraction in (0, 1].
        parser.error("--strength must be in the range (0, 1]")
    if args.n_images < 1:
        parser.error("--n_images must be at least 1")
    if args.n_rows < 1 or args.n_images % args.n_rows != 0:
        # The output grid is built with a reshape that needs every row to hold
        # the same number of images.
        parser.error("--n_images must be a positive multiple of --n_rows")
    if args.decoding_batch_size < 1:
        # A zero step breaks `range`, and a negative one decodes nothing.
        parser.error("--decoding_batch_size must be at least 1")

    # Load the models and fill in per-model defaults for cfg and steps.
    if args.model == "sdxl":
        sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
        if args.quantize:
            # Quantize only the Linear layers of both text encoders (the same
            # treatment as the "sd" branch below) and the unet with 8 bits
            # and groups of 32. The encoders are quantized exactly once.
            nn.quantize(
                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
            )
            nn.quantize(
                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
            )
            nn.quantize(sd.unet, group_size=32, bits=8)
        # `is None` (not `or`) so an explicit `--cfg 0` is honored.
        if args.cfg is None:
            args.cfg = 0.0
        args.steps = args.steps or 2
    else:
        sd = StableDiffusion(
            "stabilityai/stable-diffusion-2-1-base", float16=args.float16
        )
        if args.quantize:
            nn.quantize(
                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
            )
            nn.quantize(sd.unet, group_size=32, bits=8)
        # `is None` (not `or`) so an explicit `--cfg 0` is not replaced by 7.5.
        if args.cfg is None:
            args.cfg = 7.5
        args.steps = args.steps or 50

    # Fix the steps if they were set too low: after scaling by the strength
    # at least one denoising step must remain.
    if int(args.steps * args.strength) < 1:
        args.steps = int(math.ceil(1 / args.strength))
        if args.verbose:
            print(
                f"Strength {args.strength} is too low so steps were set to {args.steps}"
            )

    # Ensure that models are read in memory if needed
    if args.preload_models:
        sd.ensure_models_are_loaded()

    # Read the image. Converting to RGB gives three color channels for
    # grayscale/palette/CMYK inputs and drops any alpha channel. The `with`
    # block releases the file handle.
    with Image.open(args.image) as src:
        img = src.convert("RGB")

    # Make sure image shape is divisible by 64
    W, H = (dim - dim % 64 for dim in (img.width, img.height))
    if W == 0 or H == 0:
        # Rounding down to a multiple of 64 would leave an empty image.
        parser.error(
            f"image must be at least 64x64 pixels, got {img.width}x{img.height}"
        )
    if W != img.width or H != img.height:
        print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter

    # Map 8-bit pixel values in [0, 255] to floats in [-1, 1].
    img = mx.array(np.array(img))
    img = (img.astype(mx.float32) / 255) * 2 - 1

    # Noise and denoise the latents produced by encoding the img.
    latents = sd.generate_latents_from_image(
        img,
        args.prompt,
        strength=args.strength,
        n_images=args.n_images,
        cfg_weight=args.cfg,
        num_steps=args.steps,
        negative_text=args.negative_prompt,
        seed=args.seed,
    )
    # `latents` is iterated step by step. Each yielded value is evaluated, and
    # the last one (x_t) is what gets decoded below.
    for x_t in tqdm(latents, total=int(args.steps * args.strength)):
        mx.eval(x_t)

    # The following is not necessary but it may help in memory
    # constrained systems by reusing the memory kept by the unet and the text
    # encoders.
    if args.model == "sdxl":
        del sd.text_encoder_1
        del sd.text_encoder_2
    else:
        del sd.text_encoder
    del sd.unet
    del sd.sampler
    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3

    # Decode them into images, `decoding_batch_size` latents at a time.
    decoded = []
    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
        decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
        mx.eval(decoded[-1])
    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3

    # Arrange them on a grid: pad each image by 8 pixels on every side, then
    # tile the batch into n_rows rows. n_images % n_rows == 0 was checked above.
    # NOTE(review): assumes decoded images are in [0, 1] before the uint8
    # conversion — confirm against sd.decode.
    x = mx.concatenate(decoded, axis=0)
    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
    B, H, W, C = x.shape
    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
    x = (x * 255).astype(mx.uint8)

    # Save them to disc
    im = Image.fromarray(np.array(x))
    im.save(args.output)

    # Report the peak memory used during generation
    if args.verbose:
        print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|