diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md
index 5e44cb1a..ab7affa3 100644
--- a/stable_diffusion/README.md
+++ b/stable_diffusion/README.md
@@ -61,6 +61,24 @@ script in the root of the repository. You can use the script as follows:

     python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2

+Image 2 Image
+-------------
+
+There is also the option of generating images based on another image using the
+example script `image2image.py`. To do that, the image is first encoded with the
+autoencoder to get its latent representation, and then noise is added according
+to the forward diffusion process and the `strength` parameter. A `strength` of
+0.0 means no added noise and a `strength` of 1.0 means starting from completely
+random noise.
+
+![image2image](im2im.png)
+*Generations with varying strength using the original image and the prompt 'A lit fireplace'.*
+
+The command to generate the above images is:
+
+    python image2image.py --strength 0.5 original.png 'A lit fireplace'
+
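+The same pipeline can also be driven from Python. The following is a minimal
+sketch of what `image2image.py` does (argument parsing, the progress bar and
+saving to disk are omitted):
+
+    import mlx.core as mx
+    import numpy as np
+    from PIL import Image
+
+    from stable_diffusion import StableDiffusion
+
+    sd = StableDiffusion()
+
+    # Map the input image to [-1, 1], as the autoencoder expects
+    img = mx.array(np.array(Image.open("original.png")))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+
+    # Partially noise the encoded image and then iteratively denoise it
+    for x_t in sd.generate_latents_from_image(img, "A lit fireplace", strength=0.5):
+        mx.eval(x_t)
+
+    # Decode the final latents into images with values in [0, 1]
+    images = sd.decode(x_t)
+    mx.eval(images)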
+
 Performance
 -----------
diff --git a/stable_diffusion/im2im.png b/stable_diffusion/im2im.png
new file mode 100644
index 00000000..959d208b
Binary files /dev/null and b/stable_diffusion/im2im.png differ
diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py
new file mode 100644
index 00000000..ca325d48
--- /dev/null
+++ b/stable_diffusion/image2image.py
@@ -0,0 +1,65 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from stable_diffusion import StableDiffusion
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Generate images from an image and a textual prompt using stable diffusion"
+    )
+    parser.add_argument("image")
+    parser.add_argument("prompt")
+    parser.add_argument("--strength", type=float, default=0.9)
+    parser.add_argument("--n_images", type=int, default=4)
+    parser.add_argument("--steps", type=int, default=50)
+    parser.add_argument("--cfg", type=float, default=7.5)
+    parser.add_argument("--negative_prompt", default="")
+    parser.add_argument("--n_rows", type=int, default=1)
+    parser.add_argument("--decoding_batch_size", type=int, default=1)
+    parser.add_argument("--output", default="out.png")
+    args = parser.parse_args()
+
+    sd = StableDiffusion()
+
+    # Read the image
+    img = mx.array(np.array(Image.open(args.image)))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+
+    # Noise and denoise the latents produced by encoding img.
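+    # With a strength of s, sampling starts from time s * T instead of pure
+    # noise and only int(steps * s) denoising steps are run (see
+    # StableDiffusion.generate_latents_from_image).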
+    latents = sd.generate_latents_from_image(
+        img,
+        args.prompt,
+        strength=args.strength,
+        n_images=args.n_images,
+        cfg_weight=args.cfg,
+        num_steps=args.steps,
+        negative_text=args.negative_prompt,
+    )
+    for x_t in tqdm(latents, total=int(args.steps * args.strength)):
+        mx.simplify(x_t)
+        mx.simplify(x_t)
+        mx.eval(x_t)
+
+    # Decode them into images
+    decoded = []
+    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+        decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
+        mx.eval(decoded[-1])
+
+    # Arrange them on a grid
+    x = mx.concatenate(decoded, axis=0)
+    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
+    B, H, W, C = x.shape
+    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
+    x = (x * 255).astype(mx.uint8)
+
+    # Save them to disk
+    im = Image.fromarray(x.__array__())
+    im.save(args.output)
diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index 899336e8..1d72be41 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -41,20 +41,13 @@ class StableDiffusion:
         self.sampler = SimpleEulerSampler(self.diffusion_config)
         self.tokenizer = load_tokenizer(model)

-    def generate_latents(
+    def _get_text_conditioning(
         self,
         text: str,
         n_images: int = 1,
-        num_steps: int = 50,
         cfg_weight: float = 7.5,
         negative_text: str = "",
-        latent_size: Tuple[int] = (64, 64),
-        seed=None,
     ):
-        # Set the PRNG state
-        seed = seed or int(time.time())
-        mx.random.seed(seed)
-
         # Tokenize the text
         tokens = [self.tokenizer.tokenize(text)]
         if cfg_weight > 1:
@@ -71,27 +64,84 @@ class StableDiffusion:
         if n_images > 1:
             conditioning = _repeat(conditioning, n_images, axis=0)

+        return conditioning
+
+    def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
+        x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
+        t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+        eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+
+        if cfg_weight > 1:
+            eps_text, eps_neg = eps_pred.split(2)
+            eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+
+        x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
+
+        return x_t_prev
+
+    def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
+        x_t = x_T
+        for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
+            x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
+            yield x_t
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+        latent_size: Tuple[int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Get the text conditioning
+        conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
+
         # Create the latent variables
         x_T = self.sampler.sample_prior(
             (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
         )

         # Perform the denoising loop
-        x_t = x_T
-        for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
-            x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
-            t_unet = mx.broadcast_to(t, [len(x_t_unet)])
-            eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+        yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)

-            if cfg_weight > 1:
-                eps_text, eps_neg = eps_pred.split(2)
-                eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+    def generate_latents_from_image(
+        self,
+        image,
+        text: str,
+        n_images: int = 1,
+        strength: float = 0.8,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)

-            x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
-            x_t = x_t_prev
-            yield x_t
+        # Define the num steps and start step
+        start_step = self.sampler.max_time * strength
+        num_steps = int(num_steps * strength)
+
+        # Get the text conditioning
+        conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
+
+        # Get the latents from the input image and add noise according to the
+        # start time.
+        x_0, _ = self.autoencoder.encode(image[None])
+        x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
+        x_T = self.sampler.add_noise(x_0, mx.array(start_step))
+
+        # Perform the denoising loop
+        yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)

     def decode(self, x_t):
-        x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
+        x = self.autoencoder.decode(x_t)
         x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
         return x
diff --git a/stable_diffusion/stable_diffusion/sampler.py b/stable_diffusion/stable_diffusion/sampler.py
index 0c7b80ee..41555daf 100644
--- a/stable_diffusion/stable_diffusion/sampler.py
+++ b/stable_diffusion/stable_diffusion/sampler.py
@@ -49,17 +49,28 @@ class SimpleEulerSampler:
             [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
         )

+    @property
+    def max_time(self):
+        return len(self._sigmas) - 1
+
     def sample_prior(self, shape, dtype=mx.float32, key=None):
         noise = mx.random.normal(shape, key=key)
         return (
             noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
         ).astype(dtype)

+    def add_noise(self, x, t, key=None):
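+        # Noise x to the intermediate time t of the forward process. With
+        # sigma_t = sqrt((1 - a_t) / a_t), where a_t is the cumulative product
+        # of the alphas, (x + sigma_t * eps) / sqrt(sigma_t^2 + 1) equals the
+        # DDPM forward sample sqrt(a_t) * x + sqrt(1 - a_t) * eps, matching
+        # sample_prior above.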
+        noise = mx.random.normal(x.shape, key=key)
+        s = self.sigmas(t)
+        return (x + noise * s) * (s.square() + 1).rsqrt()
+
     def sigmas(self, t):
         return _interp(self._sigmas, t)

-    def timesteps(self, num_steps: int, dtype=mx.float32):
-        steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
+    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
+        start_time = start_time or (len(self._sigmas) - 1)
+        assert 0 < start_time <= (len(self._sigmas) - 1)
+        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
         return list(zip(steps, steps[1:]))

     def step(self, eps_pred, x_t, t, t_prev):
diff --git a/stable_diffusion/stable_diffusion/vae.py b/stable_diffusion/stable_diffusion/vae.py
index fe473d4c..5fd47f13 100644
--- a/stable_diffusion/stable_diffusion/vae.py
+++ b/stable_diffusion/stable_diffusion/vae.py
@@ -67,7 +67,7 @@ class EncoderDecoderBlock2D(nn.Module):
         # Add an optional downsampling layer
         if add_downsample:
             self.downsample = nn.Conv2d(
-                out_channels, out_channels, kernel_size=3, stride=2, padding=1
+                out_channels, out_channels, kernel_size=3, stride=2, padding=0
             )

         # or upsampling layer
@@ -81,6 +81,7 @@ class EncoderDecoderBlock2D(nn.Module):
             x = resnet(x)

         if "downsample" in self:
+            # Pad only on the bottom/right so the strided convolution matches
+            # the asymmetric padding of the original Stable Diffusion VAE
+            x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
             x = self.downsample(x)

         if "upsample" in self:
@@ -253,16 +254,21 @@ class Autoencoder(nn.Module):
         )

     def decode(self, z):
+        z = z / self.scaling_factor
         return self.decoder(self.post_quant_proj(z))

-    def __call__(self, x, key=None):
+    def encode(self, x):
         x = self.encoder(x)
         x = self.quant_proj(x)
-
         mean, logvar = x.split(2, axis=-1)
-        std = mx.exp(0.5 * logvar)
-        z = mx.random.normal(mean.shape, key=key) * std + mean
+        # Fold the scaling factor into the returned distribution so the latents
+        # can be used directly by the diffusion model (the variance scales with
+        # the square of the factor, hence the 2 * log term)
+        mean = mean * self.scaling_factor
+        logvar = logvar + 2 * math.log(self.scaling_factor)
+
+        return mean, logvar
+
+    def __call__(self, x, key=None):
+        mean, logvar = self.encode(x)
+        z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
         x_hat = self.decode(z)

         return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)