mlx-examples/video/Wan2.1/wan/utils/fm_solvers.py
2025-07-28 17:07:26 -07:00

562 lines
22 KiB
Python

import math
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
def get_sampling_sigmas(sampling_steps, shift):
    """Return ``sampling_steps`` shifted sigmas running from 1 down toward 0.

    An evenly spaced grid ``linspace(1, 0, sampling_steps + 1)`` is built and
    truncated to its first ``sampling_steps`` entries (dropping the trailing
    0). Each value ``x`` is then mapped through
    ``shift * x / (1 + (shift - 1) * x)``; ``shift == 1`` leaves the grid
    unchanged.

    Args:
        sampling_steps: number of sigmas to produce.
        shift: timestep-shift factor.

    Returns:
        numpy array of length ``sampling_steps``.
    """
    unshifted = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    numerator = shift * unshifted
    denominator = 1 + (shift - 1) * unshifted
    return numerator / denominator
def retrieve_timesteps(
    scheduler,
    num_inference_steps=None,
    device=None,
    timesteps=None,
    sigmas=None,
    **kwargs,
):
    """Call ``scheduler.set_timesteps`` and return the resulting schedule.

    The schedule is specified in exactly one way: custom ``timesteps``,
    custom ``sigmas``, or (when neither is given) ``num_inference_steps``.
    Any extra ``kwargs`` are forwarded to ``set_timesteps``.

    Args:
        scheduler: object exposing ``set_timesteps(...)`` and a
            ``timesteps`` attribute that it populates.
        num_inference_steps: number of steps when no custom schedule is given.
        device: forwarded unchanged to ``set_timesteps``.
        timesteps: optional custom timestep schedule.
        sigmas: optional custom sigma schedule.

    Returns:
        ``(scheduler.timesteps, num_inference_steps)``. For custom schedules
        the step count is ``len(scheduler.timesteps)``; otherwise it is the
        ``num_inference_steps`` argument.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are passed, or if the
            scheduler's ``set_timesteps`` cannot accept the requested custom
            schedule.
    """
    import inspect

    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    def _set_timesteps_accepts(name):
        # Detect unsupported custom schedules up front so the caller gets a
        # clear error instead of a TypeError from inside set_timesteps (e.g.
        # FlowDPMSolverMultistepScheduler.set_timesteps has no `timesteps`).
        try:
            params = inspect.signature(scheduler.set_timesteps).parameters
        except (TypeError, ValueError):
            # Signature not introspectable; let the call itself decide.
            return True
        return name in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

    if timesteps is not None:
        if not _set_timesteps_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support "
                "custom timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _set_timesteps_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support "
                "custom sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
class SchedulerOutput:
    """Result container returned by a scheduler ``step`` call.

    Attributes:
        prev_sample: the sample computed by ``step``.
    """

    def __init__(self, prev_sample: mx.array):
        """Wrap ``prev_sample`` so it can be read as ``output.prev_sample``."""
        self.prev_sample = prev_sample
class FlowDPMSolverMultistepScheduler:
    """
    MLX implementation of FlowDPMSolverMultistepScheduler.
    A fast dedicated high-order solver for diffusion ODEs.

    Specialised to flow-matching models: the network output is treated as a
    velocity ``v``. With ``alpha_t = 1 - sigma_t`` (see
    ``_sigma_to_alpha_sigma_t``) the data prediction is
    ``x0 = x - sigma_t * v`` and the noise prediction is
    ``eps = x + (1 - sigma_t) * v`` (equivalently ``eps = x0 + v``).

    Configuration lives in the ``self.config`` dict. Call ``set_timesteps``
    once, then ``step`` for each entry of ``self.timesteps``.
    """

    order = 1

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting: bool = False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        final_sigmas_type: Optional[str] = "zero",
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        invert_sigmas: bool = False,
    ):
        """Store the configuration and build the training sigma schedule.

        Args:
            num_train_timesteps: length of the training schedule.
            solver_order: multistep order used by ``step`` (1, 2 or 3).
            prediction_type: must be ``"flow_prediction"``.
            shift: sigma shift applied when ``use_dynamic_shifting`` is False.
            use_dynamic_shifting: if True, ``set_timesteps`` requires ``mu``
                and applies ``time_shift`` instead of ``shift``.
            thresholding, dynamic_thresholding_ratio, sample_max_value:
                dynamic thresholding of the data prediction.
            algorithm_type: "dpmsolver", "dpmsolver++", "sde-dpmsolver" or
                "sde-dpmsolver++" ("deis" is mapped to "dpmsolver++").
            solver_type: "midpoint" or "heun" ("logrho", "bh1", "bh2" are
                mapped to "midpoint").
            lower_order_final, euler_at_final: fall back to lower-order
                updates near the end of the schedule.
            final_sigmas_type: "zero" or "sigma_min"; the sigma appended after
                the last inference step by ``set_timesteps``.
            lambda_min_clipped, variance_type, invert_sigmas: stored in
                ``self.config`` only; not read elsewhere in this class.

        Raises:
            NotImplementedError: unsupported ``algorithm_type``/``solver_type``.
            ValueError: ``final_sigmas_type == "zero"`` with a noise-prediction
                algorithm ("dpmsolver"/"sde-dpmsolver"); their final update at
                sigma = 0 evaluates ``0 * inf`` and yields NaN.
        """
        # Store configuration
        self.config = {
            'num_train_timesteps': num_train_timesteps,
            'solver_order': solver_order,
            'prediction_type': prediction_type,
            'shift': shift,
            'use_dynamic_shifting': use_dynamic_shifting,
            'thresholding': thresholding,
            'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
            'sample_max_value': sample_max_value,
            'algorithm_type': algorithm_type,
            'solver_type': solver_type,
            'lower_order_final': lower_order_final,
            'euler_at_final': euler_at_final,
            'final_sigmas_type': final_sigmas_type,
            'lambda_min_clipped': lambda_min_clipped,
            'variance_type': variance_type,
            'invert_sigmas': invert_sigmas,
        }

        # Validate algorithm type
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.config['algorithm_type'] = "dpmsolver++"
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented")

        # Validate solver type
        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.config['solver_type'] = "midpoint"
            else:
                raise NotImplementedError(f"{solver_type} is not implemented")

        # Noise-prediction variants compute log(sigma_t) and sigma_t * (exp(h) - 1);
        # at sigma_t == 0 that is 0 * inf == NaN, so a zero final sigma is invalid.
        if (self.config['algorithm_type'] not in ["dpmsolver++", "sde-dpmsolver++"]
                and final_sigmas_type == "zero"):
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for "
                f"`algorithm_type` {self.config['algorithm_type']}. Please choose `sigma_min` instead."
            )

        # Initialize scheduling
        self.num_inference_steps = None
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = mx.array(1.0 - alphas, dtype=mx.float32)
        if not use_dynamic_shifting:
            # With dynamic shifting the shift is applied later, in set_timesteps.
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """Current index into ``self.sigmas``; ``None`` until the first ``step``."""
        return self._step_index

    @property
    def begin_index(self):
        """First step index set via ``set_begin_index``, or ``None``."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """Fix the first step index so ``step`` does not look up the timestep."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, None] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """Sets the discrete timesteps used for the diffusion chain.

        Args:
            num_inference_steps: number of steps; used when ``sigmas`` is None
                to space sigmas linearly between ``sigma_max`` and ``sigma_min``.
            device: unused; accepted for API compatibility.
            sigmas: optional custom sigma schedule (list or array), one entry
                per inference step.
            mu: required when ``use_dynamic_shifting`` is enabled.
            shift: overrides ``config['shift']`` when not dynamically shifting.

        Raises:
            ValueError: missing ``mu`` with dynamic shifting, neither
                ``num_inference_steps`` nor ``sigmas`` given, or an unknown
                ``final_sigmas_type``.
        """
        if self.config['use_dynamic_shifting'] and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
            )

        if sigmas is None:
            if num_inference_steps is None:
                raise ValueError(
                    "Either `num_inference_steps` or `sigmas` has to be passed to `set_timesteps`"
                )
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
        else:
            # Accept plain lists as advertised; arrays keep their own dtype.
            sigmas = np.array(sigmas)

        if self.config['use_dynamic_shifting']:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            if shift is None:
                shift = self.config['shift']
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        if self.config['final_sigmas_type'] == "sigma_min":
            sigma_last = self.sigma_min
        elif self.config['final_sigmas_type'] == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
            )

        timesteps = sigmas * self.config['num_train_timesteps']
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = mx.array(sigmas)
        # Cast on the host so float -> int64 truncates toward zero explicitly.
        self.timesteps = mx.array(np.asarray(timesteps).astype(np.int64))
        self.num_inference_steps = len(timesteps)

        self.model_outputs = [None] * self.config['solver_order']
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None

    @staticmethod
    def _quantile_last_axis(x: mx.array, q: float) -> mx.array:
        """Linearly interpolated ``q``-quantile along the last axis (dims kept).

        Built from ``mx.sort`` instead of relying on a quantile op; the
        interpolation is ``lower + (upper - lower) * frac`` between the two
        order statistics around position ``q * (n - 1)``.
        """
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"quantile must be in [0, 1], got {q}")
        n = x.shape[-1]
        ordered = mx.sort(x, axis=-1)
        pos = q * (n - 1)
        lo = int(math.floor(pos))
        hi = min(lo + 1, n - 1)
        frac = pos - lo
        lower = ordered[..., lo:lo + 1]
        upper = ordered[..., hi:hi + 1]
        return lower + (upper - lower) * frac

    def _threshold_sample(self, sample: mx.array) -> mx.array:
        """Dynamic thresholding method.

        Per batch element, take the ``dynamic_thresholding_ratio`` quantile
        ``s`` of ``|sample|`` (clipped to ``[1, sample_max_value]``), clip the
        sample to ``[-s, s]`` and divide by ``s``. Assumes ``sample`` has at
        least two dims (batch, channels, ...).
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        # Upcast half precision so sorting/interpolation are accurate.
        work = sample.astype(mx.float32) if dtype in (mx.float16, mx.bfloat16) else sample

        # Flatten each batch element; math.prod([]) == 1 keeps the size an int.
        sample_flat = work.reshape(batch_size, channels * math.prod(remaining_dims))
        abs_sample = mx.abs(sample_flat)

        s = self._quantile_last_axis(abs_sample, self.config['dynamic_thresholding_ratio'])
        # Clipping s at a minimum of 1 makes this standard clipping to [-1, 1].
        s = mx.clip(s, 1, self.config['sample_max_value'])

        # Threshold and normalize
        sample_flat = mx.clip(sample_flat, -s, s) / s
        sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
        return sample.astype(dtype)

    def _sigma_to_t(self, sigma):
        """Map a sigma to its (continuous) training timestep."""
        return sigma * self.config['num_train_timesteps']

    def _sigma_to_alpha_sigma_t(self, sigma):
        """Return ``(alpha_t, sigma_t)`` with ``alpha_t = 1 - sigma``."""
        return 1 - sigma, sigma

    def time_shift(self, mu: float, sigma: float, t: mx.array):
        """Resolution-dependent shift ``e^mu / (e^mu + (1/t - 1) ** sigma)``.

        ``set_timesteps`` calls this with a numpy array ``t``.
        """
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)

    def convert_model_output(
        self,
        model_output: mx.array,
        sample: mx.array,
        **kwargs,
    ) -> mx.array:
        """Convert model output to the corresponding type the algorithm needs.

        DPM-Solver++ variants integrate the data prediction
        ``x0 = sample - sigma_t * v``; DPM-Solver variants integrate the noise
        prediction ``eps = sample + (1 - sigma_t) * v``. Both use the sigma at
        the current ``step_index``.

        Raises:
            ValueError: if ``prediction_type`` is not ``"flow_prediction"``.
        """
        # DPM-Solver++ needs to solve an integral of the data prediction model
        if self.config['algorithm_type'] in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be "
                    f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                )
            if self.config['thresholding']:
                x0_pred = self._threshold_sample(x0_pred)
            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model
        elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                # eps = x0 + v = (sample - sigma_t * v) + v; this matches the
                # thresholding branch below when thresholding is the identity.
                epsilon = sample + (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be "
                    f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                )
            if self.config['thresholding']:
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred
            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: mx.array,
        sample: mx.array,
        noise: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """One step for the first-order DPMSolver (equivalent to DDIM).

        Moves ``sample`` from ``sigmas[step_index]`` to
        ``sigmas[step_index + 1]``. ``model_output`` must already be converted
        by ``convert_model_output``; ``noise`` is required for SDE variants.
        """
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        # lambda = log(alpha / sigma) is the half log-SNR.
        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s = mx.log(alpha_s) - mx.log(sigma_s)
        h = lambda_t - lambda_s

        if self.config['algorithm_type'] == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
        elif self.config['algorithm_type'] == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mx.exp(h) - 1.0)) * model_output
        elif self.config['algorithm_type'] == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * mx.exp(-h)) * sample +
                (alpha_t * (1 - mx.exp(-2.0 * h))) * model_output +
                sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
            )
        elif self.config['algorithm_type'] == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample -
                2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
                sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
            )
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[mx.array],
        sample: mx.array,
        noise: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """One step for the second-order multistep DPMSolver.

        Uses the last two converted model outputs in ``model_output_list``
        (newest last) and the sigmas at ``step_index - 1``, ``step_index`` and
        ``step_index + 1``. ``noise`` is required for SDE variants.
        """
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
        lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        # D1 is a scaled finite-difference estimate of the first derivative.
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        if self.config['algorithm_type'] == "dpmsolver++":
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample -
                    (alpha_t * (mx.exp(-h) - 1.0)) * D0 -
                    0.5 * (alpha_t * (mx.exp(-h) - 1.0)) * D1
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample -
                    (alpha_t * (mx.exp(-h) - 1.0)) * D0 +
                    (alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config['algorithm_type'] == "dpmsolver":
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    0.5 * (sigma_t * (mx.exp(h) - 1.0)) * D1
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config['algorithm_type'] == "sde-dpmsolver++":
            assert noise is not None
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * mx.exp(-h)) * sample +
                    (alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
                    0.5 * (alpha_t * (1 - mx.exp(-2.0 * h))) * D1 +
                    sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * mx.exp(-h)) * sample +
                    (alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
                    (alpha_t * ((1.0 - mx.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 +
                    sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
                )
        elif self.config['algorithm_type'] == "sde-dpmsolver":
            assert noise is not None
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    (sigma_t * (mx.exp(h) - 1.0)) * D1 +
                    sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
                    sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[mx.array],
        sample: mx.array,
        **kwargs,
    ) -> mx.array:
        """One step for the third-order multistep DPMSolver.

        Uses the last three converted model outputs (newest last). Only the
        deterministic "dpmsolver++" and "dpmsolver" algorithms are supported.

        Raises:
            NotImplementedError: for the SDE algorithm types.
        """
        if self.config['algorithm_type'] not in ["dpmsolver++", "dpmsolver"]:
            raise NotImplementedError(
                f"third-order update is not implemented for algorithm_type "
                f"{self.config['algorithm_type']}; use solver_order <= 2"
            )

        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
        lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
        lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

        if self.config['algorithm_type'] == "dpmsolver++":
            x_t = (
                (sigma_t / sigma_s0) * sample -
                (alpha_t * (mx.exp(-h) - 1.0)) * D0 +
                (alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1 -
                (alpha_t * ((mx.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        else:  # "dpmsolver"
            x_t = (
                (alpha_t / alpha_s0) * sample -
                (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 -
                (sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of ``timestep`` in ``schedule_timesteps``.

        Defaults to ``self.timesteps``. When the timestep occurs more than
        once, the second occurrence is returned.

        Raises:
            IndexError: if the timestep does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        # mx.where has no single-argument (nonzero) form, so search on the host.
        matches = np.nonzero(np.array(schedule_timesteps) == np.array(timestep))[0]
        if len(matches) == 0:
            raise IndexError(f"timestep {timestep} not found in the timestep schedule")
        pos = 1 if len(matches) > 1 else 0
        return int(matches[pos])

    def _init_step_index(self, timestep):
        """Initialize the step_index counter for the scheduler."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: mx.array,
        timestep: Union[int, mx.array],
        sample: mx.array,
        generator=None,
        variance_noise: Optional[mx.array] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Predict the sample from the previous timestep.

        Args:
            model_output: raw (velocity) model output at ``timestep``.
            timestep: current timestep; used to find the step index on the
                first call when ``begin_index`` is not set.
            sample: current sample ``x_t``.
            generator: unused in this MLX port.
            variance_noise: noise for the SDE variants; drawn with
                ``mx.random.normal`` when omitted.
            return_dict: return a ``SchedulerOutput`` if True, else a 1-tuple.

        Raises:
            ValueError: if ``set_timesteps`` has not been called.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (
            (self.step_index == len(self.timesteps) - 1) and
            (self.config['euler_at_final'] or
             (self.config['lower_order_final'] and len(self.timesteps) < 15) or
             self.config['final_sigmas_type'] == "zero")
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and
            self.config['lower_order_final'] and
            len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        # Shift the history window left and append the newest output.
        for i in range(self.config['solver_order'] - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues
        sample = sample.astype(mx.float32)

        # Generate noise if needed for SDE variants
        if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = mx.random.normal(model_output.shape, dtype=mx.float32)
        elif self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.astype(mx.float32)
        else:
            noise = None

        # Use a lower order until enough history exists (and near the end).
        if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(
                model_output, sample=sample, noise=noise
            )
        elif self.config['solver_order'] == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(
                self.model_outputs, sample=sample, noise=noise
            )
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(
                self.model_outputs, sample=sample
            )

        if self.lower_order_nums < self.config['solver_order']:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.astype(model_output.dtype)

        # Increase step index
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
        """Scale model input - no scaling needed for this scheduler."""
        return sample

    def add_noise(
        self,
        original_samples: mx.array,
        noise: mx.array,
        timesteps: mx.array,
    ) -> mx.array:
        """Add noise to original samples.

        Returns ``(1 - sigma) * original_samples + sigma * noise`` where sigma
        is looked up per entry of ``timesteps`` (or taken at the current/begin
        step index when ``begin_index`` is set) and broadcast over the
        remaining sample dimensions.
        """
        sigmas = self.sigmas.astype(original_samples.dtype)
        schedule_timesteps = self.timesteps

        # Get step indices
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Index with an mx.array rather than a Python list.
        sigma = sigmas[mx.array(step_indices)]
        while len(sigma.shape) < len(original_samples.shape):
            sigma = mx.expand_dims(sigma, -1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        """Number of training timesteps."""
        return self.config['num_train_timesteps']