mlx-examples/stable_diffusion/stable_diffusion/sampler.py
2024-03-08 10:24:19 -08:00

106 lines
3.3 KiB
Python

# Copyright © 2023 Apple Inc.
import mlx.core as mx
from .config import DiffusionConfig
def _linspace(a, b, num):
    """Return ``num`` evenly spaced values from ``a`` to ``b`` (both inclusive).

    Mirrors ``numpy.linspace(a, b, num)``: for ``num == 1`` the result is
    ``[a]`` and for ``num == 0`` it is empty.
    """
    if num == 1:
        # arange(0, 1) / 0 would be 0/0 -> [nan]; a single sample is just `a`.
        x = mx.zeros(1)
    else:
        x = mx.arange(0, num) / (num - 1)
    return (b - a) * x + a
def _interp(y, x_new):
    """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new.

    Piecewise-linear interpolation between the samples of ``y``, which sit at
    integer positions ``0 .. len(y) - 1``. The right neighbour is clamped to
    the last sample, so a position equal to ``len(y) - 1`` is exact.

    NOTE(review): the floor is taken via an int32 cast (truncation toward
    zero), so this assumes ``0 <= x_new <= len(y) - 1``.
    """
    last_index = len(y) - 1
    left = x_new.astype(mx.int32)
    right = mx.minimum(left + 1, last_index)
    weight = x_new - left  # fractional distance past the left sample
    return y[left] * (1 - weight) + weight * y[right]
class SimpleEulerSampler:
    """A simple Euler integrator that can be used to sample from our diffusion models.

    The noise schedule is stored in ``self._sigmas``, which has
    ``config.num_train_steps + 1`` entries. Index 0 is 0, and index ``i >= 1`` is
    ``sqrt((1 - alpha_bar) / alpha_bar)`` for the i-th training step.

    Latents passed to and returned from the public methods are scaled by
    ``1 / sqrt(sigma**2 + 1)``. ``step()`` undoes that scaling, takes the
    Euler step, and re-applies the scaling for the new sigma.

    The method ``step()`` performs one Euler step from x_t to x_t_prev.
    """

    def __init__(self, config: DiffusionConfig):
        # Compute the noise schedule
        if config.beta_schedule == "linear":
            betas = _linspace(
                config.beta_start, config.beta_end, config.num_train_steps
            )
        elif config.beta_schedule == "scaled_linear":
            # Linear in sqrt(beta), then squared back.
            betas = _linspace(
                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
            ).square()
        else:
            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")

        alphas = 1 - betas
        alphas_cumprod = mx.cumprod(alphas)
        # Prepend sigma = 0 so that index t == 0 means "no noise".
        self._sigmas = mx.concatenate(
            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
        )

    @property
    def max_time(self):
        """Largest valid time index, i.e. ``len(self._sigmas) - 1``."""
        return len(self._sigmas) - 1

    def sample_prior(self, shape, dtype=mx.float32, key=None):
        """Draw an initial latent x_T of the given ``shape``.

        Standard normal noise is scaled by ``sigma_max / sqrt(sigma_max**2 + 1)``,
        where ``sigma_max`` is the last schedule entry, and then cast to ``dtype``.
        """
        noise = mx.random.normal(shape, key=key)
        return (
            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
        ).astype(dtype)

    def add_noise(self, x, t, key=None):
        """Noise the clean latents ``x`` up to time ``t``.

        Returns ``(x + noise * sigma(t)) / sqrt(sigma(t)**2 + 1)`` in ``x.dtype``.
        ``t`` is passed to ``sigmas()``, so it should be an MLX array.
        NOTE(review): ``x`` is expected to be a floating-point array.
        """
        noise = mx.random.normal(x.shape, key=key)
        s = self.sigmas(t)
        x_t = (x + noise * s) * (s.square() + 1).rsqrt()
        # noise (default float32) and the float32 schedule would otherwise
        # promote half-precision latents to float32. Cast back so this method
        # keeps the working dtype, like sample_prior(dtype=...) and step().
        return x_t.astype(x.dtype)

    def sigmas(self, t):
        """Return sigma at (possibly fractional) time(s) ``t``, linearly interpolated."""
        return _interp(self._sigmas, t)

    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
        """Return ``num_steps`` consecutive ``(t, t_prev)`` pairs.

        The times are evenly spaced from ``start_time`` down to 0, cast to
        ``dtype``, and taken as elements of a 1-D MLX array.

        NOTE: a falsy ``start_time`` (``None`` or ``0``) falls back to
        ``max_time``. Otherwise ``start_time`` must satisfy
        ``0 < start_time <= max_time``.
        """
        start_time = start_time or self.max_time
        assert 0 < start_time <= self.max_time
        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
        return list(zip(steps, steps[1:]))

    def step(self, eps_pred, x_t, t, t_prev):
        """Take one deterministic Euler step from ``x_t`` (time ``t``) to time ``t_prev``.

        ``eps_pred`` is the model's noise prediction. Sigmas are cast to its
        dtype so the update stays in that precision.
        """
        sigma = self.sigmas(t).astype(eps_pred.dtype)
        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)

        dt = sigma_prev - sigma
        # Undo the 1/sqrt(sigma**2 + 1) scaling, step, then rescale for sigma_prev.
        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()

        return x_t_prev
class SimpleEulerAncestralSampler(SimpleEulerSampler):
    """Euler "ancestral" variant of ``SimpleEulerSampler``.

    Each step splits the move from ``sigma(t)`` to ``sigma(t_prev)`` into two parts:
    - a deterministic Euler step down to ``sigma_down``;
    - fresh Gaussian noise of scale ``sigma_up``.
    The two are chosen so that ``sigma_up**2 + sigma_down**2 == sigma(t_prev)**2``.
    """

    def step(self, eps_pred, x_t, t, t_prev):
        """Take one stochastic Euler step from ``x_t`` (time ``t``) to time ``t_prev``.

        Sigmas are cast to ``eps_pred``'s dtype. The injected noise comes from
        MLX's global random state (no explicit key).
        NOTE(review): divides by ``sigma(t)**2``, so this assumes ``sigma(t) > 0``.
        """
        work_dtype = eps_pred.dtype
        s_cur = self.sigmas(t).astype(work_dtype)
        s_next = self.sigmas(t_prev).astype(work_dtype)
        s_cur_sq = s_cur.square()
        s_next_sq = s_next.square()

        # Split the target variance s_next**2 into stochastic (up) and
        # deterministic (down) parts.
        s_up = (s_next_sq * (s_cur_sq - s_next_sq) / s_cur_sq).sqrt()
        s_down = (s_next_sq - s_up**2).sqrt()

        # Undo the 1/sqrt(sigma**2 + 1) scaling and step deterministically to s_down.
        unscaled = (s_cur_sq + 1).sqrt() * x_t + eps_pred * (s_down - s_cur)

        # Re-inject noise of scale s_up, then rescale for s_next.
        fresh_noise = mx.random.normal(unscaled.shape).astype(unscaled.dtype)
        return (unscaled + fresh_noise * s_up) * (s_next_sq + 1).rsqrt()