Fix image2image for SDXL (#563)

---------

Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
This commit is contained in:
devonthomas35 2024-03-11 12:18:47 -07:00 committed by GitHub
parent d0fa6cfcae
commit fe5edee360
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 8 deletions

View File

@ -7,7 +7,7 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
from stable_diffusion import StableDiffusion
from stable_diffusion import StableDiffusion, StableDiffusionXL
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@ -15,10 +15,11 @@ if __name__ == "__main__":
)
parser.add_argument("image")
parser.add_argument("prompt")
parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
parser.add_argument("--strength", type=float, default=0.9)
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--cfg", type=float, default=7.5)
parser.add_argument("--steps", type=int)
parser.add_argument("--cfg", type=float)
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
@ -29,10 +30,23 @@ if __name__ == "__main__":
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
if args.model == "sdxl":
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
if args.quantize:
QuantizedLinear.quantize_module(sd.text_encoder_1)
QuantizedLinear.quantize_module(sd.text_encoder_2)
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 0.0
args.steps = args.steps or 2
else:
sd = StableDiffusion(
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
)
if args.quantize:
QuantizedLinear.quantize_module(sd.text_encoder)
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 7.5
args.steps = args.steps or 50
if args.preload_models:
sd.ensure_models_are_loaded()
@ -64,6 +78,10 @@ if __name__ == "__main__":
# The following is not necessary but it may help in memory
# constrained systems by reusing the memory kept by the unet and the text
# encoders.
if args.model == "sdxl":
del sd.text_encoder_1
del sd.text_encoder_2
else:
del sd.text_encoder
del sd.unet
del sd.sampler

View File

@ -264,3 +264,42 @@ class StableDiffusionXL(StableDiffusion):
cfg_weight,
text_time=text_time,
)
def generate_latents_from_image(
self,
image,
text: str,
n_images: int = 1,
strength: float = 0.8,
num_steps: int = 2,
cfg_weight: float = 0.0,
negative_text: str = "",
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Define the num steps and start step
start_step = self.sampler.max_time * strength
num_steps = int(num_steps * strength)
# Get the text conditioning
conditioning, pooled_conditioning = self._get_text_conditioning(
text, n_images, cfg_weight, negative_text
)
text_time = (
pooled_conditioning,
mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
)
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, (n_images,) + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop
yield from self._denoising_loop(
x_T, start_step, conditioning, num_steps, cfg_weight, text_time=text_time
)