Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-06-24 01:17:28 +08:00

Add an image2image example in the stable diffusion (#198)

This commit is contained in:
parent 09566c7257
commit 37fd2464dc
@@ -61,6 +61,24 @@ script in the root of the repository. You can use the script as follows:
 
     python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
 
+Image 2 Image
+-------------
+
+There is also the option of generating images based on another image using the
+example script `image2image.py`. To do that, an image is first encoded using the
+autoencoder to get its latent representation, and then noise is added according
+to the forward diffusion process and the `strength` parameter. A `strength` of
+0.0 means no noise and a `strength` of 1.0 means starting from completely
+random noise.
+
+
+![image2image](im2im.png)
+*Generations with varying strength using the original image and the prompt 'A lit fireplace'.*
+
+The command to generate the above images is:
+
+    python image2image.py --strength 0.5 original.png 'A lit fireplace'
+
 Performance
 -----------
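Concretely, the example maps `strength` linearly onto both the forward-noising start time and the number of denoising steps (see `generate_latents_from_image` further down in this diff). A minimal NumPy sketch of that jump, using a hypothetical sigma schedule:

import numpy as np

# Hypothetical sigma schedule: sigma grows with t, as in the Euler sampler.
max_time = 1000
sigmas = np.linspace(0.0, 14.6, max_time + 1)

def add_noise(x0, t, rng):
    # Mirrors SimpleEulerSampler.add_noise below:
    # (x + noise * sigma) / sqrt(sigma^2 + 1)
    s = sigmas[int(t)]
    noise = rng.standard_normal(x0.shape)
    return (x0 + noise * s) / np.sqrt(s**2 + 1)

strength, num_steps = 0.5, 50
start_step = max_time * strength          # jump halfway into the forward process
steps_to_run = int(num_steps * strength)  # only 25 denoising steps remain

rng = np.random.default_rng(0)
x0 = rng.standard_normal((1, 64, 64, 4))  # an encoded image latent
x_T = add_noise(x0, start_step, rng)
print(f"start at t={start_step:.0f}, run {steps_to_run} steps")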
BIN  stable_diffusion/im2im.png  (new file, 462 KiB; binary file not shown)
65  stable_diffusion/image2image.py  (new file)
@@ -0,0 +1,65 @@
# Copyright © 2023 Apple Inc.

import argparse

import mlx.core as mx
import numpy as np
from PIL import Image
from tqdm import tqdm

from stable_diffusion import StableDiffusion

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from an image and a textual prompt using stable diffusion"
    )
    parser.add_argument("image")
    parser.add_argument("prompt")
    parser.add_argument("--strength", type=float, default=0.9)
    parser.add_argument("--n_images", type=int, default=4)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--cfg", type=float, default=7.5)
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--n_rows", type=int, default=1)
    parser.add_argument("--decoding_batch_size", type=int, default=1)
    parser.add_argument("--output", default="out.png")
    args = parser.parse_args()

    sd = StableDiffusion()

    # Read the image and map it from [0, 255] to [-1, 1]
    img = mx.array(np.array(Image.open(args.image)))
    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1

    # Noise and denoise the latents produced by encoding img.
    latents = sd.generate_latents_from_image(
        img,
        args.prompt,
        strength=args.strength,
        n_images=args.n_images,
        cfg_weight=args.cfg,
        num_steps=args.steps,
        negative_text=args.negative_prompt,
    )
    for x_t in tqdm(latents, total=int(args.steps * args.strength)):
        # Simplify the computation graph before evaluating each latent
        mx.simplify(x_t)
        mx.simplify(x_t)
        mx.eval(x_t)

    # Decode them into images
    decoded = []
    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
        decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
        mx.eval(decoded[-1])

    # Arrange them on a grid
    x = mx.concatenate(decoded, axis=0)
    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
    B, H, W, C = x.shape
    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
    x = (x * 255).astype(mx.uint8)

    # Save them to disk
    im = Image.fromarray(x.__array__())
    im.save(args.output)
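The grid assembly at the end of the script is two reshapes around a transpose. A quick NumPy check with tiny hypothetical shapes shows the (rows, H, cols, W) interleaving:

import numpy as np

B, H, W, C, n_rows = 4, 2, 3, 1, 2
x = np.arange(B * H * W * C).reshape(B, H, W, C)
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(n_rows * H, B // n_rows * W, C)
print(x.shape)  # (4, 6, 1): a 2 x 2 grid of 2 x 3 images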
@@ -41,20 +41,13 @@ class StableDiffusion:
         self.sampler = SimpleEulerSampler(self.diffusion_config)
         self.tokenizer = load_tokenizer(model)
 
-    def generate_latents(
+    def _get_text_conditioning(
         self,
         text: str,
         n_images: int = 1,
-        num_steps: int = 50,
         cfg_weight: float = 7.5,
         negative_text: str = "",
-        latent_size: Tuple[int] = (64, 64),
-        seed=None,
     ):
-        # Set the PRNG state
-        seed = seed or int(time.time())
-        mx.random.seed(seed)
-
         # Tokenize the text
         tokens = [self.tokenizer.tokenize(text)]
         if cfg_weight > 1:
@@ -71,27 +64,84 @@ class StableDiffusion:
         if n_images > 1:
             conditioning = _repeat(conditioning, n_images, axis=0)
 
+        return conditioning
+
+    def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
+        x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
+        t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+        eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+
+        if cfg_weight > 1:
+            eps_text, eps_neg = eps_pred.split(2)
+            eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+
+        x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
+
+        return x_t_prev
+
+    def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
+        x_t = x_T
+        for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
+            x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
+            yield x_t
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+        latent_size: Tuple[int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Get the text conditioning
+        conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
+
         # Create the latent variables
         x_T = self.sampler.sample_prior(
             (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
         )
 
         # Perform the denoising loop
-        x_t = x_T
-        for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
-            x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
-            t_unet = mx.broadcast_to(t, [len(x_t_unet)])
-            eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
-
-            if cfg_weight > 1:
-                eps_text, eps_neg = eps_pred.split(2)
-                eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
-
-            x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
-            x_t = x_t_prev
-            yield x_t
+        yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)
+
+    def generate_latents_from_image(
+        self,
+        image,
+        text: str,
+        n_images: int = 1,
+        strength: float = 0.8,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Define the num steps and start step
+        start_step = self.sampler.max_time * strength
+        num_steps = int(num_steps * strength)
+
+        # Get the text conditioning
+        conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
+
+        # Get the latents from the input image and add noise according to the
+        # start time.
+        x_0, _ = self.autoencoder.encode(image[None])
+        x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
+        x_T = self.sampler.add_noise(x_0, mx.array(start_step))
+
+        # Perform the denoising loop
+        yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)
 
     def decode(self, x_t):
-        x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
+        x = self.autoencoder.decode(x_t)
         x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
         return x
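The refactored `_denoising_step` above is where classifier-free guidance happens: one batched UNet call produces both the prompt-conditioned and the negative-prompt noise predictions, `eps_pred.split(2)` separates them, and the final prediction extrapolates from the negative toward the text direction. A tiny NumPy illustration with made-up values:

import numpy as np

eps_text = np.array([1.0, -0.5])  # prediction conditioned on the prompt
eps_neg = np.array([0.2, 0.1])    # negative-prompt (unconditional) prediction
cfg_weight = 7.5

# Push the prediction past eps_text, away from eps_neg.
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
print(eps_pred)  # [ 6.2 -4.4]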
@@ -49,17 +49,28 @@ class SimpleEulerSampler:
             [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
         )
 
+    @property
+    def max_time(self):
+        return len(self._sigmas) - 1
+
     def sample_prior(self, shape, dtype=mx.float32, key=None):
         noise = mx.random.normal(shape, key=key)
         return (
             noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
         ).astype(dtype)
 
+    def add_noise(self, x, t, key=None):
+        noise = mx.random.normal(x.shape, key=key)
+        s = self.sigmas(t)
+        return (x + noise * s) * (s.square() + 1).rsqrt()
+
+    def sigmas(self, t):
+        return _interp(self._sigmas, t)
+
-    def timesteps(self, num_steps: int, dtype=mx.float32):
-        steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
+    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
+        start_time = start_time or (len(self._sigmas) - 1)
+        assert 0 < start_time <= (len(self._sigmas) - 1)
+        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
         return list(zip(steps, steps[1:]))
 
     def step(self, eps_pred, x_t, t, t_prev):
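A quick consistency check on the new sampler methods: at the maximum time, `add_noise` applied to a zero latent reproduces exactly the scaling used by `sample_prior`, which is why a `strength` of 1.0 effectively restarts from pure noise. NumPy sketch with a hypothetical top sigma:

import numpy as np

rng = np.random.default_rng(0)
sigma_max = 14.6  # hypothetical value of _sigmas[-1]
noise = rng.standard_normal((2, 2))

# sample_prior: noise * sigma / sqrt(sigma^2 + 1)
prior = noise * sigma_max / np.sqrt(sigma_max**2 + 1)

# add_noise at t = max_time: (x + noise * sigma) / sqrt(sigma^2 + 1)
noised = (np.zeros((2, 2)) + noise * sigma_max) / np.sqrt(sigma_max**2 + 1)

print(np.allclose(prior, noised))  # True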
@@ -67,7 +67,7 @@ class EncoderDecoderBlock2D(nn.Module):
         # Add an optional downsampling layer
         if add_downsample:
             self.downsample = nn.Conv2d(
-                out_channels, out_channels, kernel_size=3, stride=2, padding=1
+                out_channels, out_channels, kernel_size=3, stride=2, padding=0
             )
 
         # or upsampling layer
@@ -81,6 +81,7 @@ class EncoderDecoderBlock2D(nn.Module):
             x = resnet(x)
 
         if "downsample" in self:
+            x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
            x = self.downsample(x)
 
         if "upsample" in self:
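A plausible reading of this change (an assumption on my part, not stated in the commit): the reference Stable Diffusion VAE downsamples with padding on the bottom/right only, which a symmetric `padding=1` cannot express, so the conv is switched to `padding=0` with an explicit asymmetric `mx.pad` beforehand. A quick shape check with hypothetical sizes:

import mlx.core as mx
import mlx.nn as nn

x = mx.zeros((1, 64, 64, 8))                     # NHWC input, hypothetical sizes
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])  # pad H and W at the end only
conv = nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)
print(conv(x).shape)  # (1, 32, 32, 8): same spatial size as symmetric padding,
                      # without a zero row/column on the top/left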
@@ -253,16 +254,21 @@ class Autoencoder(nn.Module):
         )
 
     def decode(self, z):
+        z = z / self.scaling_factor
         return self.decoder(self.post_quant_proj(z))
 
-    def __call__(self, x, key=None):
+    def encode(self, x):
         x = self.encoder(x)
         x = self.quant_proj(x)
-
         mean, logvar = x.split(2, axis=-1)
-        std = mx.exp(0.5 * logvar)
-        z = mx.random.normal(mean.shape, key=key) * std + mean
+        mean = mean * self.scaling_factor
+        logvar = logvar + 2 * math.log(self.scaling_factor)
+
+        return mean, logvar
+
+    def __call__(self, x, key=None):
+        mean, logvar = self.encode(x)
+        z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
         x_hat = self.decode(z)
 
         return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
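With this refactor, `scaling_factor` lives entirely inside the autoencoder: `encode` scales the latent distribution (multiplying the mean by `scaling_factor` and adding `2 * log(scaling_factor)` to the log-variance, since variance scales quadratically) and `decode` divides it back out, so `StableDiffusion.decode` above no longer rescales. A one-line check of the log-variance identity:

import math

scaling_factor = 0.18215  # the usual Stable Diffusion VAE scaling factor
logvar = -2.0             # hypothetical encoder output

# Scaling a Gaussian by s multiplies its std by s, i.e. adds 2*log(s) to logvar.
lhs = scaling_factor * math.exp(0.5 * logvar)
rhs = math.exp(0.5 * (logvar + 2 * math.log(scaling_factor)))
print(math.isclose(lhs, rhs))  # True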