Add an image2image example in the stable diffusion (#198)

This commit is contained in:
Angelos Katharopoulos 2023-12-28 18:31:45 -08:00 committed by GitHub
parent 09566c7257
commit 37fd2464dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 27 deletions

View File

@ -61,6 +61,24 @@ script in the root of the repository. You can use the script as follows:
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
Image 2 Image
-------------
There is also the option of generating images based on another image using the
example script `image2image.py`. To do that an image is first encoded using the
autoencoder to get its latent representation and then noise is added according
to the forward diffusion process and the `strength` parameter. A `stregnth` of
0.0 means no noise and a `strength` of 1.0 means starting from completely
random noise.
![image2image](im2im.png)
*Generations with varying strength using the original image and the prompt 'A lit fireplace'.*
The command to generate the above images is:
python image2image.py --strength 0.5 original.png 'A lit fireplace'
Performance
-----------

BIN
stable_diffusion/im2im.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 462 KiB

View File

@ -0,0 +1,65 @@
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
import numpy as np
from PIL import Image
from tqdm import tqdm
from stable_diffusion import StableDiffusion
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from an image and a textual prompt using stable diffusion"
)
parser.add_argument("image")
parser.add_argument("prompt")
parser.add_argument("--strength", type=float, default=0.9)
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--cfg", type=float, default=7.5)
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--output", default="out.png")
args = parser.parse_args()
sd = StableDiffusion()
# Read the image
img = mx.array(np.array(Image.open(args.image)))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
# Noise and denoise the latents produced by encoding img.
latents = sd.generate_latents_from_image(
img,
args.prompt,
strength=args.strength,
n_images=args.n_images,
cfg_weight=args.cfg,
num_steps=args.steps,
negative_text=args.negative_prompt,
)
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
mx.simplify(x_t)
mx.simplify(x_t)
mx.eval(x_t)
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
mx.eval(decoded[-1])
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(x.__array__())
im.save(args.output)

View File

@ -41,20 +41,13 @@ class StableDiffusion:
self.sampler = SimpleEulerSampler(self.diffusion_config)
self.tokenizer = load_tokenizer(model)
def generate_latents(
def _get_text_conditioning(
self,
text: str,
n_images: int = 1,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
latent_size: Tuple[int] = (64, 64),
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Tokenize the text
tokens = [self.tokenizer.tokenize(text)]
if cfg_weight > 1:
@ -71,27 +64,84 @@ class StableDiffusion:
if n_images > 1:
conditioning = _repeat(conditioning, n_images, axis=0)
return conditioning
def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
return x_t_prev
def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
latent_size: Tuple[int] = (64, 64),
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
# Create the latent variables
x_T = self.sampler.sample_prior(
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
)
# Perform the denoising loop
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
def generate_latents_from_image(
self,
image,
text: str,
n_images: int = 1,
strength: float = 0.8,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
x_t = x_t_prev
yield x_t
# Define the num steps and start step
start_step = self.sampler.max_time * strength
num_steps = int(num_steps * strength)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop
yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)
def decode(self, x_t):
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
x = self.autoencoder.decode(x_t)
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
return x

View File

@ -49,17 +49,28 @@ class SimpleEulerSampler:
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
@property
def max_time(self):
return len(self._sigmas) - 1
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def add_noise(self, x, t, key=None):
noise = mx.random.normal(x.shape, key=key)
s = self.sigmas(t)
return (x + noise * s) * (s.square() + 1).rsqrt()
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, dtype=mx.float32):
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
start_time = start_time or (len(self._sigmas) - 1)
assert 0 < start_time <= (len(self._sigmas) - 1)
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def step(self, eps_pred, x_t, t, t_prev):

View File

@ -67,7 +67,7 @@ class EncoderDecoderBlock2D(nn.Module):
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
@ -81,6 +81,7 @@ class EncoderDecoderBlock2D(nn.Module):
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
@ -253,16 +254,21 @@ class Autoencoder(nn.Module):
)
def decode(self, z):
z = z / self.scaling_factor
return self.decoder(self.post_quant_proj(z))
def __call__(self, x, key=None):
def encode(self, x):
x = self.encoder(x)
x = self.quant_proj(x)
mean, logvar = x.split(2, axis=-1)
std = mx.exp(0.5 * logvar)
z = mx.random.normal(mean.shape, key=key) * std + mean
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
return mean, logvar
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)