mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Implement Wan2.2
This commit is contained in:
562
video/Wan2.2/wan/utils/fm_solvers.py
Normal file
562
video/Wan2.2/wan/utils/fm_solvers.py
Normal file
@@ -0,0 +1,562 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_sampling_sigmas(sampling_steps, shift):
    """Build the shifted sigma schedule used for sampling.

    ``sampling_steps + 1`` evenly spaced values from 1 down to 0 are
    generated, the trailing 0 is dropped, and every remaining value ``s`` is
    mapped to ``shift * s / (1 + (shift - 1) * s)``.

    Args:
        sampling_steps: Number of sigma values to return.
        shift: Shift factor; ``1`` leaves the linear schedule unchanged.

    Returns:
        A NumPy array holding ``sampling_steps`` sigma values.
    """
    linear = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    denominator = 1 + (shift - 1) * linear
    return shift * linear / denominator
|
||||
|
||||
|
||||
def retrieve_timesteps(
    scheduler,
    num_inference_steps=None,
    device=None,
    timesteps=None,
    sigmas=None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    Exactly one of three modes is used:

    * ``timesteps`` given: forwarded as ``set_timesteps(timesteps=...)``.
    * ``sigmas`` given: forwarded as ``set_timesteps(sigmas=...)``.
    * neither given: ``set_timesteps(num_inference_steps, ...)`` is called.

    ``device`` and any extra ``kwargs`` are passed through to
    ``scheduler.set_timesteps`` in every mode.

    Args:
        scheduler: Object exposing ``set_timesteps`` and ``timesteps``.
        num_inference_steps: Step count used when no custom schedule is given.
        device: Forwarded to ``set_timesteps``.
        timesteps: Optional custom timesteps.
        sigmas: Optional custom sigmas.

    Returns:
        A ``(timesteps, num_inference_steps)`` tuple. With a custom schedule
        the step count is ``len(scheduler.timesteps)``; otherwise it is the
        ``num_inference_steps`` argument unchanged.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are provided.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    # No custom schedule: let the scheduler derive one from the step count.
    if not (custom_timesteps or custom_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if custom_timesteps:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
||||
|
||||
|
||||
class SchedulerOutput:
    """Result container returned by a scheduler ``step`` call.

    Attributes:
        prev_sample: The array handed to the constructor, stored unchanged.
    """

    def __init__(self, prev_sample: mx.array):
        # Plain attribute storage; no copy or conversion is performed.
        self.prev_sample = prev_sample
|
||||
|
||||
|
||||
class FlowDPMSolverMultistepScheduler:
    """
    MLX implementation of FlowDPMSolverMultistepScheduler.
    A fast dedicated high-order solver for diffusion ODEs.

    The scheduler works on a flow parameterisation where ``alpha = 1 - sigma``
    (see ``_sigma_to_alpha_sigma_t``). All constructor arguments are kept in
    the ``config`` dictionary; mutable stepping state lives in ``sigmas``,
    ``timesteps``, ``model_outputs``, ``lower_order_nums`` and the private
    step/begin indices.
    """

    order = 1

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting: bool = False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        final_sigmas_type: Optional[str] = "zero",
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        invert_sigmas: bool = False,
    ):
        """Store the configuration and build the training sigma schedule.

        Args:
            num_train_timesteps: Length of the training schedule; also the
                factor converting a sigma into a timestep.
            solver_order: Number of past model outputs kept for multistep
                updates (1, 2 or 3 select the update used by ``step``).
            prediction_type: Must be ``"flow_prediction"`` for
                ``convert_model_output`` to succeed.
            shift: Static shift applied to the sigmas when
                ``use_dynamic_shifting`` is false.
            use_dynamic_shifting: If true, ``set_timesteps`` requires ``mu``
                and uses ``time_shift`` instead of the static shift.
            thresholding: Enables ``_threshold_sample`` on the predicted x0.
            dynamic_thresholding_ratio: Quantile used by the thresholding.
            sample_max_value: Upper clip bound of the thresholding scale.
            algorithm_type: One of ``"dpmsolver"``, ``"dpmsolver++"``,
                ``"sde-dpmsolver"``, ``"sde-dpmsolver++"``; ``"deis"`` is
                mapped to ``"dpmsolver++"``.
            solver_type: ``"midpoint"`` or ``"heun"``; ``"logrho"``, ``"bh1"``
                and ``"bh2"`` are mapped to ``"midpoint"``.
            lower_order_final: Allows lower-order updates on the last steps
                when fewer than 15 timesteps are used.
            euler_at_final: Forces a first-order update on the final step.
            final_sigmas_type: ``"zero"`` or ``"sigma_min"``; value appended
                to the sigma schedule in ``set_timesteps``.
            lambda_min_clipped: Stored in ``config`` only.
            variance_type: Stored in ``config`` only.
            invert_sigmas: Stored in ``config`` only.

        Raises:
            NotImplementedError: For an unsupported ``algorithm_type`` or
                ``solver_type``.
        """
        # Store configuration
        self.config = {
            'num_train_timesteps': num_train_timesteps,
            'solver_order': solver_order,
            'prediction_type': prediction_type,
            'shift': shift,
            'use_dynamic_shifting': use_dynamic_shifting,
            'thresholding': thresholding,
            'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
            'sample_max_value': sample_max_value,
            'algorithm_type': algorithm_type,
            'solver_type': solver_type,
            'lower_order_final': lower_order_final,
            'euler_at_final': euler_at_final,
            'final_sigmas_type': final_sigmas_type,
            'lambda_min_clipped': lambda_min_clipped,
            'variance_type': variance_type,
            'invert_sigmas': invert_sigmas,
        }

        # Validate algorithm type
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.config['algorithm_type'] = "dpmsolver++"
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented")

        # Validate solver type
        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.config['solver_type'] = "midpoint"
            else:
                raise NotImplementedError(f"{solver_type} is not implemented")

        # Initialize scheduling
        self.num_inference_steps = None
        # alphas ascend from 1/N to 1, so sigmas descend from 1 - 1/N to 0.
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = mx.array(sigmas, dtype=mx.float32)

        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None

        self.sigma_min = float(self.sigmas[-1].item())
        self.sigma_max = float(self.sigmas[0].item())

    @property
    def step_index(self):
        """Index of the current step, or ``None`` before the first ``step``."""
        return self._step_index

    @property
    def begin_index(self):
        """Explicit starting index set via ``set_begin_index``, or ``None``."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """Set the index ``step`` starts from instead of looking it up."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, None] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """Sets the discrete timesteps used for the diffusion chain.

        Args:
            num_inference_steps: Number of steps; used to build a linear sigma
                schedule when ``sigmas`` is not given.
            device: Accepted for interface compatibility; not used here.
            sigmas: Optional custom sigmas (list or array). The configured
                shift is still applied to them.
            mu: Required when ``use_dynamic_shifting`` is enabled.
            shift: Overrides the configured static shift for this call.

        Raises:
            ValueError: If dynamic shifting is enabled without ``mu``, if
                neither ``num_inference_steps`` nor ``sigmas`` is given, or if
                ``final_sigmas_type`` is not ``"zero"`` / ``"sigma_min"``.
        """
        if self.config['use_dynamic_shifting'] and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
            )

        if sigmas is None:
            if num_inference_steps is None:
                raise ValueError(
                    "Either `num_inference_steps` or `sigmas` must be provided"
                )
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
        else:
            # A plain Python list does not support the element-wise
            # arithmetic below, so normalise custom schedules to an array.
            sigmas = np.asarray(sigmas)

        if self.config['use_dynamic_shifting']:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            if shift is None:
                shift = self.config['shift']
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        if self.config['final_sigmas_type'] == "sigma_min":
            sigma_last = self.sigma_min
        elif self.config['final_sigmas_type'] == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
            )

        timesteps = sigmas * self.config['num_train_timesteps']
        # One extra sigma is appended so that sigmas[step_index + 1] exists
        # on the final step.
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = mx.array(sigmas)
        self.timesteps = mx.array(timesteps, dtype=mx.int64)

        self.num_inference_steps = len(timesteps)

        # Reset the multistep history and the step counters.
        self.model_outputs = [None] * self.config['solver_order']
        self.lower_order_nums = 0

        self._step_index = None
        self._begin_index = None

    def _threshold_sample(self, sample: mx.array) -> mx.array:
        """Dynamic thresholding method.

        For every batch item the ``dynamic_thresholding_ratio`` quantile ``s``
        of the absolute values is computed, clipped to
        ``[1, sample_max_value]``, and the sample is clipped to ``[-s, s]``
        and divided by ``s``.

        Args:
            sample: Array whose first axis is treated as the batch axis.

        Returns:
            The thresholded array with the shape and dtype of ``sample``.
        """
        dtype = sample.dtype
        original_shape = sample.shape
        batch_size = original_shape[0]

        # Flatten to one row per batch item; compute in float32 so the final
        # cast back to the input dtype is the only precision loss.
        sample_flat = sample.astype(mx.float32).reshape(batch_size, -1)
        abs_sorted = mx.sort(mx.abs(sample_flat), axis=1)

        # Quantile by linear interpolation between the two nearest ranks,
        # built from a sort so no dedicated quantile op is required.
        count = abs_sorted.shape[1]
        position = self.config['dynamic_thresholding_ratio'] * (count - 1)
        lower = int(math.floor(position))
        upper = min(lower + 1, count - 1)
        weight = position - lower
        s = (
            (1.0 - weight) * abs_sorted[:, lower:lower + 1] +
            weight * abs_sorted[:, upper:upper + 1]
        )
        s = mx.clip(s, 1, self.config['sample_max_value'])

        # Threshold and normalize
        sample_flat = mx.clip(sample_flat, -s, s) / s

        return sample_flat.reshape(original_shape).astype(dtype)

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with ``num_train_timesteps``."""
        return sigma * self.config['num_train_timesteps']

    def _sigma_to_alpha_sigma_t(self, sigma):
        """Return the ``(alpha, sigma)`` pair with ``alpha = 1 - sigma``."""
        return 1 - sigma, sigma

    def time_shift(self, mu: float, sigma: float, t: mx.array):
        """Apply the dynamic shift ``e^mu / (e^mu + (1/t - 1)^sigma)`` to ``t``."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)

    def convert_model_output(
        self,
        model_output: mx.array,
        sample: mx.array,
        **kwargs,
    ) -> mx.array:
        """Convert model output to the corresponding type the algorithm needs.

        Args:
            model_output: Raw model prediction for the current step.
            sample: Current sample.

        Returns:
            The x0 prediction for the ``++`` algorithms, otherwise the
            epsilon prediction.

        Raises:
            ValueError: If ``prediction_type`` is not ``"flow_prediction"``.
        """
        # DPM-Solver++ needs to solve an integral of the data prediction model
        if self.config['algorithm_type'] in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be "
                    f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                )

            if self.config['thresholding']:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model
        elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be "
                    f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                )

            if self.config['thresholding']:
                # Threshold in x0 space, then rebuild epsilon from the result.
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: mx.array,
        sample: mx.array,
        noise: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output: Converted model output for the current step.
            sample: Current sample.
            noise: Noise array; required by the SDE algorithm types.

        Returns:
            The sample at the next sigma.
        """
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)

        # lambda = log(alpha) - log(sigma); h is the step in lambda.
        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s = mx.log(alpha_s) - mx.log(sigma_s)

        h = lambda_t - lambda_s

        if self.config['algorithm_type'] == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
        elif self.config['algorithm_type'] == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mx.exp(h) - 1.0)) * model_output
        elif self.config['algorithm_type'] == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * mx.exp(-h)) * sample +
                (alpha_t * (1 - mx.exp(-2.0 * h))) * model_output +
                sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
            )
        elif self.config['algorithm_type'] == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample -
                2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
                sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
            )

        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[mx.array],
        sample: mx.array,
        noise: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """One step for the second-order multistep DPMSolver.

        Args:
            model_output_list: Converted model outputs; the last two entries
                are used (most recent last).
            sample: Current sample.
            noise: Noise array; required by the SDE algorithm types.

        Returns:
            The sample at the next sigma.
        """
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
        lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        # D1 is the difference of the two latest outputs scaled by 1 / r0.
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        if self.config['algorithm_type'] == "dpmsolver++":
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample -
                    (alpha_t * (mx.exp(-h) - 1.0)) * D0 -
                    0.5 * (alpha_t * (mx.exp(-h) - 1.0)) * D1
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample -
                    (alpha_t * (mx.exp(-h) - 1.0)) * D0 +
                    (alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config['algorithm_type'] == "dpmsolver":
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    0.5 * (sigma_t * (mx.exp(h) - 1.0)) * D1
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config['algorithm_type'] == "sde-dpmsolver++":
            assert noise is not None
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * mx.exp(-h)) * sample +
                    (alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
                    0.5 * (alpha_t * (1 - mx.exp(-2.0 * h))) * D1 +
                    sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * mx.exp(-h)) * sample +
                    (alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
                    (alpha_t * ((1.0 - mx.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 +
                    sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
                )
        elif self.config['algorithm_type'] == "sde-dpmsolver":
            assert noise is not None
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    (sigma_t * (mx.exp(h) - 1.0)) * D1 +
                    sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
                    sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                )

        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[mx.array],
        sample: mx.array,
        **kwargs,
    ) -> mx.array:
        """One step for the third-order multistep DPMSolver.

        Args:
            model_output_list: Converted model outputs; the last three
                entries are used (most recent last).
            sample: Current sample.

        Returns:
            The sample at the next sigma.

        Raises:
            NotImplementedError: If ``algorithm_type`` is not ``"dpmsolver"``
                or ``"dpmsolver++"``; no third-order SDE update is defined.
        """
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
        lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
        lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

        if self.config['algorithm_type'] == "dpmsolver++":
            x_t = (
                (sigma_t / sigma_s0) * sample -
                (alpha_t * (mx.exp(-h) - 1.0)) * D0 +
                (alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1 -
                (alpha_t * ((mx.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config['algorithm_type'] == "dpmsolver":
            x_t = (
                (alpha_t / alpha_s0) * sample -
                (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 -
                (sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        else:
            # Without this branch x_t would be unbound and the return below
            # would fail with an unrelated UnboundLocalError.
            raise NotImplementedError(
                f"third-order update is not implemented for {self.config['algorithm_type']}"
            )

        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index whose timestep equals ``timestep``.

        Args:
            timestep: Timestep value to look up.
            schedule_timesteps: Schedule to search; defaults to
                ``self.timesteps``.

        Returns:
            The matching index as a Python ``int``. When the value occurs more
            than once, the second occurrence is returned.

        Raises:
            ValueError: If ``timestep`` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        # The number of matches is data dependent, so the boolean mask is
        # pulled to the host and scanned there instead of using an array op.
        hits = (schedule_timesteps == timestep).reshape(-1).tolist()
        indices = [position for position, hit in enumerate(hits) if hit]

        if not indices:
            raise ValueError(
                f"timestep {timestep} was not found in the scheduler timesteps"
            )

        pos = 1 if len(indices) > 1 else 0

        return int(indices[pos])

    def _init_step_index(self, timestep):
        """Initialize the step_index counter for the scheduler.

        Uses ``begin_index`` when one was set, otherwise looks ``timestep``
        up in the current schedule.
        """
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: mx.array,
        timestep: Union[int, mx.array],
        sample: mx.array,
        generator=None,
        variance_noise: Optional[mx.array] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Predict the sample from the previous timestep.

        Args:
            model_output: Raw model prediction for the current step.
            timestep: Current timestep; only used to initialise the step
                index on the first call when no ``begin_index`` was set.
            sample: Current sample.
            generator: Accepted for interface compatibility; not used here.
            variance_noise: Optional noise for the SDE algorithm types; drawn
                with ``mx.random.normal`` when omitted.
            return_dict: If false, a 1-tuple ``(prev_sample,)`` is returned
                instead of a ``SchedulerOutput``.

        Returns:
            ``SchedulerOutput`` holding ``prev_sample``, or a 1-tuple.

        Raises:
            ValueError: If ``set_timesteps`` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (
            (self.step_index == len(self.timesteps) - 1) and
            (self.config['euler_at_final'] or
             (self.config['lower_order_final'] and len(self.timesteps) < 15) or
             self.config['final_sigmas_type'] == "zero")
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and
            self.config['lower_order_final'] and
            len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        # Shift the history left by one and append the newest output.
        for i in range(self.config['solver_order'] - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues
        sample = sample.astype(mx.float32)

        # Generate noise if needed for SDE variants
        if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = mx.random.normal(model_output.shape, dtype=mx.float32)
        elif self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.astype(mx.float32)
        else:
            noise = None

        # lower_order_nums counts how many outputs are available, so the
        # first steps warm up with lower-order updates.
        if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(
                model_output, sample=sample, noise=noise
            )
        elif self.config['solver_order'] == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(
                self.model_outputs, sample=sample, noise=noise
            )
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(
                self.model_outputs, sample=sample
            )

        if self.lower_order_nums < self.config['solver_order']:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.astype(model_output.dtype)

        # Increase step index
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
        """Scale model input - no scaling needed for this scheduler."""
        return sample

    def add_noise(
        self,
        original_samples: mx.array,
        noise: mx.array,
        timesteps: mx.array,
    ) -> mx.array:
        """Add noise to original samples.

        Computes ``(1 - sigma) * original_samples + sigma * noise`` with one
        sigma per entry of ``timesteps``.

        Args:
            original_samples: Clean samples.
            noise: Noise with a shape broadcastable to ``original_samples``.
            timesteps: One timestep per sample along the first axis.

        Returns:
            The noised samples.
        """
        sigmas = self.sigmas.astype(original_samples.dtype)
        schedule_timesteps = self.timesteps

        # Get step indices
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Gather with an index array rather than a Python list.
        sigma = sigmas[mx.array(step_indices)]
        # Append trailing axes so sigma broadcasts against the samples.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = mx.expand_dims(sigma, -1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config['num_train_timesteps']
|
||||
Reference in New Issue
Block a user