Add an image2image example in the stable diffusion (#198)

This commit is contained in:
Angelos Katharopoulos
2023-12-28 18:31:45 -08:00
committed by GitHub
parent 09566c7257
commit 37fd2464dc
6 changed files with 177 additions and 27 deletions

View File

@@ -41,20 +41,13 @@ class StableDiffusion:
self.sampler = SimpleEulerSampler(self.diffusion_config)
self.tokenizer = load_tokenizer(model)
def generate_latents(
def _get_text_conditioning(
self,
text: str,
n_images: int = 1,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
latent_size: Tuple[int] = (64, 64),
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Tokenize the text
tokens = [self.tokenizer.tokenize(text)]
if cfg_weight > 1:
@@ -71,27 +64,84 @@ class StableDiffusion:
if n_images > 1:
conditioning = _repeat(conditioning, n_images, axis=0)
return conditioning
def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
return x_t_prev
def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
latent_size: Tuple[int] = (64, 64),
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
# Create the latent variables
x_T = self.sampler.sample_prior(
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
)
# Perform the denoising loop
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
def generate_latents_from_image(
self,
image,
text: str,
n_images: int = 1,
strength: float = 0.8,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
x_t = x_t_prev
yield x_t
# Define the num steps and start step
start_step = self.sampler.max_time * strength
num_steps = int(num_steps * strength)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop
yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)
def decode(self, x_t):
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
x = self.autoencoder.decode(x_t)
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
return x

View File

@@ -49,17 +49,28 @@ class SimpleEulerSampler:
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
@property
def max_time(self):
return len(self._sigmas) - 1
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def add_noise(self, x, t, key=None):
noise = mx.random.normal(x.shape, key=key)
s = self.sigmas(t)
return (x + noise * s) * (s.square() + 1).rsqrt()
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, dtype=mx.float32):
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
start_time = start_time or (len(self._sigmas) - 1)
assert 0 < start_time <= (len(self._sigmas) - 1)
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def step(self, eps_pred, x_t, t, t_prev):

View File

@@ -67,7 +67,7 @@ class EncoderDecoderBlock2D(nn.Module):
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
@@ -81,6 +81,7 @@ class EncoderDecoderBlock2D(nn.Module):
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
@@ -253,16 +254,21 @@ class Autoencoder(nn.Module):
)
def decode(self, z):
z = z / self.scaling_factor
return self.decoder(self.post_quant_proj(z))
def __call__(self, x, key=None):
def encode(self, x):
x = self.encoder(x)
x = self.quant_proj(x)
mean, logvar = x.split(2, axis=-1)
std = mx.exp(0.5 * logvar)
z = mx.random.normal(mean.shape, key=key) * std + mean
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
return mean, logvar
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)