mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 03:05:20 +08:00
106 lines
3.3 KiB
Python
106 lines
3.3 KiB
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import mlx.core as mx
|
|
|
|
from .config import DiffusionConfig
|
|
|
|
|
|
def _linspace(a, b, num):
    """Return ``num`` evenly spaced values running from ``a`` to ``b`` inclusive.

    The i-th value is ``a + (b - a) * i / (num - 1)`` for ``i`` in
    ``0 .. num - 1``.

    Args:
        a: Value of the first point.
        b: Value of the last point (reached only when ``num >= 2``).
        num: Number of points to generate.

    Returns:
        A one-dimensional array with ``num`` entries. For ``num == 1`` the
        result is the single point ``a``; for ``num == 0`` it is empty.
    """
    # With a single point the natural denominator ``num - 1`` is zero, which
    # would turn the only entry into NaN (0 / 0). Clamping it to 1 yields
    # ``[a]`` and leaves every ``num >= 2`` case untouched.
    denom = max(num - 1, 1)
    x = mx.arange(0, num) / denom
    return (b - a) * x + a
|
|
|
|
|
|
def _interp(y, x_new):
    """Linearly interpolate the samples ``y`` at fractional indices ``x_new``.

    ``y`` is read as the values of a function at the integer positions
    ``0, 1, ..., len(y) - 1``. Each entry of ``x_new`` is a position on that
    axis, and the result blends the two neighbouring samples.

    Args:
        y: Indexable array of samples.
        x_new: Array of positions; it must provide ``astype``.

    Returns:
        The interpolated values, one per entry of ``x_new``.
    """
    last_index = len(y) - 1

    # Casting to int32 drops the fractional part, giving the left neighbour
    # (for non-negative positions this is the floor).
    left = x_new.astype(mx.int32)
    # The right neighbour is clamped so a position on the last sample does
    # not index past the end of ``y``.
    right = mx.minimum(left + 1, last_index)

    # Fractional distance from the left neighbour, used as the blend weight.
    frac = x_new - left

    return y[left] * (1 - frac) + frac * y[right]
|
|
|
|
|
|
class SimpleEulerSampler:
    """A simple Euler integrator that can be used to sample from our diffusion models.

    The constructor turns the configured beta schedule into a table of
    sigmas, ``sqrt((1 - alphas_cumprod) / alphas_cumprod)``, with a leading
    zero so that index 0 corresponds to sigma 0. Times passed to the other
    methods are (possibly fractional) indices into that table.

    The method ``step()`` performs one Euler step from x_t to x_t_prev.
    """

    def __init__(self, config: DiffusionConfig):
        """Build the sigma table from ``config``.

        Reads ``beta_schedule``, ``beta_start``, ``beta_end`` and
        ``num_train_steps`` from ``config``.

        Raises:
            NotImplementedError: If ``config.beta_schedule`` is neither
                ``"linear"`` nor ``"scaled_linear"``.
        """
        # Compute the noise schedule
        if config.beta_schedule == "linear":
            betas = _linspace(
                config.beta_start, config.beta_end, config.num_train_steps
            )
        elif config.beta_schedule == "scaled_linear":
            # Linear in sqrt(beta), then squared.
            betas = _linspace(
                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
            ).square()
        else:
            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")

        alphas = 1 - betas
        alphas_cumprod = mx.cumprod(alphas)

        # Prepend a zero so that time 0 maps to sigma 0.
        self._sigmas = mx.concatenate(
            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
        )

    @property
    def max_time(self):
        """Largest valid time, i.e. the last index of the sigma table."""
        return len(self._sigmas) - 1

    def sample_prior(self, shape, dtype=mx.float32, key=None):
        """Draw a starting sample of the given ``shape`` at the largest time.

        Normal noise is scaled by ``sigma_max / sqrt(sigma_max**2 + 1)``,
        where ``sigma_max`` is the last entry of the sigma table, and the
        result is cast to ``dtype``. ``key`` is forwarded to
        ``mx.random.normal``.
        """
        noise = mx.random.normal(shape, key=key)
        return (
            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
        ).astype(dtype)

    def add_noise(self, x, t, key=None):
        """Noise ``x`` to time ``t``.

        Returns ``(x + noise * sigma(t)) / sqrt(sigma(t)**2 + 1)`` with
        ``noise`` drawn from ``mx.random.normal`` using ``key``.
        """
        noise = mx.random.normal(x.shape, key=key)
        s = self.sigmas(t)
        return (x + noise * s) * (s.square() + 1).rsqrt()

    def sigmas(self, t):
        """Return sigma at time(s) ``t`` by interpolating the sigma table."""
        return _interp(self._sigmas, t)

    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
        """Return ``num_steps`` consecutive ``(t, t_prev)`` pairs ending at 0.

        Args:
            num_steps: Number of pairs to produce.
            start_time: Time of the first pair's ``t``; defaults to
                ``max_time`` when ``None``. Must satisfy
                ``0 < start_time <= max_time``.
            dtype: Type the times are cast to.

        Returns:
            A list of ``(t, t_prev)`` tuples, evenly spaced from
            ``start_time`` down to 0.
        """
        # Test for None explicitly: ``start_time or default`` would treat an
        # explicit 0 as "not given" and silently start from max_time,
        # bypassing the range check below.
        if start_time is None:
            start_time = self.max_time
        assert (
            0 < start_time <= self.max_time
        ), "start_time must be in the interval (0, max_time]"
        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
        return list(zip(steps, steps[1:]))

    def step(self, eps_pred, x_t, t, t_prev):
        """Perform one Euler step from ``x_t`` at time ``t`` to time ``t_prev``.

        Args:
            eps_pred: Predicted noise; its dtype is used for the sigmas.
            x_t: Current sample.
            t: Current time.
            t_prev: Time to step to.

        Returns:
            The sample at ``t_prev``.
        """
        sigma = self.sigmas(t).astype(eps_pred.dtype)
        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)

        dt = sigma_prev - sigma
        # Undo the 1 / sqrt(sigma**2 + 1) scaling of x_t, then take the step.
        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt

        # Re-apply the scaling for the new noise level.
        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()

        return x_t_prev
|
|
|
|
|
|
class SimpleEulerAncestralSampler(SimpleEulerSampler):
    """Euler sampler whose ``step()`` adds freshly drawn noise at every step.

    Everything except ``step()`` is inherited from ``SimpleEulerSampler``.
    """

    def step(self, eps_pred, x_t, t, t_prev):
        """Perform one ancestral Euler step from time ``t`` to time ``t_prev``.

        The target sigma at ``t_prev`` is split into ``sigma_down``, reached
        by the deterministic Euler update, and ``sigma_up``, the scale of the
        new normal noise added afterwards.

        Args:
            eps_pred: Predicted noise; its dtype is used for the sigmas.
            x_t: Current sample.
            t: Current time.
            t_prev: Time to step to.

        Returns:
            The sample at ``t_prev``.
        """
        work_dtype = eps_pred.dtype
        sigma_cur = self.sigmas(t).astype(work_dtype)
        sigma_next = self.sigmas(t_prev).astype(work_dtype)

        sq_cur = sigma_cur.square()
        sq_next = sigma_next.square()

        # Split the target noise level into a deterministic and a random part.
        sigma_up = (sq_next * (sq_cur - sq_next) / sq_cur).sqrt()
        sigma_down = (sq_next - sigma_up**2).sqrt()

        # Deterministic Euler update down to sigma_down, after undoing the
        # 1 / sqrt(sigma**2 + 1) scaling of x_t.
        x_next = (sq_cur + 1).sqrt() * x_t + eps_pred * (sigma_down - sigma_cur)

        # Add new noise scaled by sigma_up.
        fresh_noise = mx.random.normal(x_next.shape).astype(x_next.dtype)
        x_next = x_next + fresh_noise * sigma_up

        # Re-apply the scaling for the new noise level.
        return x_next * (sq_next + 1).rsqrt()
|