diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py
index 8c433990..03b41332 100644
--- a/stable_diffusion/image2image.py
+++ b/stable_diffusion/image2image.py
@@ -7,7 +7,7 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
-from stable_diffusion import StableDiffusion
+from stable_diffusion import StableDiffusion, StableDiffusionXL
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -15,10 +15,11 @@ if __name__ == "__main__":
     )
     parser.add_argument("image")
     parser.add_argument("prompt")
+    parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
     parser.add_argument("--strength", type=float, default=0.9)
     parser.add_argument("--n_images", type=int, default=4)
-    parser.add_argument("--steps", type=int, default=50)
-    parser.add_argument("--cfg", type=float, default=7.5)
+    parser.add_argument("--steps", type=int)
+    parser.add_argument("--cfg", type=float)
     parser.add_argument("--negative_prompt", default="")
     parser.add_argument("--n_rows", type=int, default=1)
     parser.add_argument("--decoding_batch_size", type=int, default=1)
@@ -29,10 +30,23 @@ if __name__ == "__main__":
     parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
-    sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
-    if args.quantize:
-        QuantizedLinear.quantize_module(sd.text_encoder)
-        QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+    if args.model == "sdxl":
+        sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
+        if args.quantize:
+            QuantizedLinear.quantize_module(sd.text_encoder_1)
+            QuantizedLinear.quantize_module(sd.text_encoder_2)
+            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+        args.cfg = args.cfg or 0.0
+        args.steps = args.steps or 2
+    else:
+        sd = StableDiffusion(
+            "stabilityai/stable-diffusion-2-1-base", float16=args.float16
+        )
+        if args.quantize:
+            QuantizedLinear.quantize_module(sd.text_encoder)
+            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+        args.cfg = args.cfg or 7.5
+        args.steps = args.steps or 50
 
     if args.preload_models:
         sd.ensure_models_are_loaded()
@@ -64,7 +78,11 @@ if __name__ == "__main__":
     # The following is not necessary but it may help in memory
     # constrained systems by reusing the memory kept by the unet and the text
     # encoders.
-    del sd.text_encoder
+    if args.model == "sdxl":
+        del sd.text_encoder_1
+        del sd.text_encoder_2
+    else:
+        del sd.text_encoder
     del sd.unet
     del sd.sampler
     peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index 802a576a..a4beffb4 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -264,3 +264,42 @@ class StableDiffusionXL(StableDiffusion):
             cfg_weight,
             text_time=text_time,
         )
+
+    def generate_latents_from_image(
+        self,
+        image,
+        text: str,
+        n_images: int = 1,
+        strength: float = 0.8,
+        num_steps: int = 2,
+        cfg_weight: float = 0.0,
+        negative_text: str = "",
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Define the num steps and start step
+        start_step = self.sampler.max_time * strength
+        num_steps = int(num_steps * strength)
+
+        # Get the text conditioning
+        conditioning, pooled_conditioning = self._get_text_conditioning(
+            text, n_images, cfg_weight, negative_text
+        )
+        text_time = (
+            pooled_conditioning,
+            mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
+        )
+
+        # Get the latents from the input image and add noise according to the
+        # start time.
+        x_0, _ = self.autoencoder.encode(image[None])
+        x_0 = mx.broadcast_to(x_0, (n_images,) + x_0.shape[1:])
+        x_T = self.sampler.add_noise(x_0, mx.array(start_step))
+
+        # Perform the denoising loop
+        yield from self._denoising_loop(
+            x_T, start_step, conditioning, num_steps, cfg_weight, text_time=text_time
+        )
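
For reference, a minimal sketch of driving the new SDXL image-to-image path directly from Python rather than through the CLI. The file name, the prompt, and the [-1, 1] pixel preprocessing below are illustrative assumptions that mirror what image2image.py already does; none of them are part of this patch:

    import mlx.core as mx
    import numpy as np
    from PIL import Image

    from stable_diffusion import StableDiffusionXL

    sdxl = StableDiffusionXL("stabilityai/sdxl-turbo", float16=True)

    # Load the input image and rescale pixels from [0, 255] to [-1, 1],
    # dropping any alpha channel (assumed preprocessing, as in image2image.py).
    img = Image.open("input.png")
    img = (mx.array(np.array(img))[:, :, :3].astype(mx.float32) / 255) * 2 - 1

    # The new method defaults to the SDXL-Turbo settings (num_steps=2,
    # cfg_weight=0.0), so only the strength needs to be chosen here.
    latents = None
    for x_t in sdxl.generate_latents_from_image(
        img, "a photo of an astronaut", n_images=1, strength=0.9
    ):
        mx.eval(x_t)
        latents = x_t

    # Decode the final latents back to pixels in [0, 1].
    decoded = sdxl.decode(latents)
    mx.eval(decoded)

The same run through the script is: python image2image.py --model sdxl --strength 0.9 input.png "a photo of an astronaut". Leaving --steps and --cfg unset picks the Turbo defaults of 2 and 0.0, while --model sd keeps the previous stable-diffusion-2-1-base behavior (50 steps, cfg 7.5).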