import math
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import numpy as np


class SchedulerOutput:
    """Output class for scheduler step results."""

    def __init__(self, prev_sample: mx.array):
        self.prev_sample = prev_sample


class FlowUniPCMultistepScheduler:
    """
    MLX implementation of UniPCMultistepScheduler.

    A training-free framework designed for the fast sampling of diffusion
    models.
    """

    order = 1

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting: bool = False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: Optional[List[int]] = None,
        solver_p=None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",
    ):
        # Avoid a mutable default argument for the corrector blacklist
        disable_corrector = disable_corrector if disable_corrector is not None else []

        # Store configuration
        self.config = {
            'num_train_timesteps': num_train_timesteps,
            'solver_order': solver_order,
            'prediction_type': prediction_type,
            'shift': shift,
            'use_dynamic_shifting': use_dynamic_shifting,
            'thresholding': thresholding,
            'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
            'sample_max_value': sample_max_value,
            'predict_x0': predict_x0,
            'solver_type': solver_type,
            'lower_order_final': lower_order_final,
            'disable_corrector': disable_corrector,
            'solver_p': solver_p,
            'timestep_spacing': timestep_spacing,
            'steps_offset': steps_offset,
            'final_sigmas_type': final_sigmas_type,
        }

        # Validate solver type; legacy names fall back to "bh2"
        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.config['solver_type'] = "bh2"
            else:
                raise NotImplementedError(
                    f"{solver_type} is not implemented for {self.__class__}"
                )

        self.predict_x0 = predict_x0

        # setable values
        self.num_inference_steps = None
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = mx.array(sigmas, dtype=mx.float32)

        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None

        self.sigma_min = float(self.sigmas[-1])
        self.sigma_max = float(self.sigmas[0])

    @property
    def step_index(self):
        """The index counter for the current timestep."""
        return self._step_index

    @property
    def begin_index(self):
        """The index for the first timestep."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """Sets the begin index for the scheduler."""
        self._begin_index = begin_index
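    # The flow-matching sigma schedule runs linearly from ~1 down to
    # ~1 / num_train_timesteps. The static shift re-weights it toward the
    # high-noise end via
    #
    #     sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)
    #
    # e.g. with shift = 3.0, sigma = 0.5 maps to 1.5 / 2.0 = 0.75, so more
    # of the step budget is spent at high noise levels.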
    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[str] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
        shift: Optional[float] = None,
    ):
        """Sets the discrete timesteps used for the diffusion chain."""
        if self.config['use_dynamic_shifting'] and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
            )

        if sigmas is None:
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
        else:
            sigmas = np.asarray(sigmas, dtype=np.float64)

        if self.config['use_dynamic_shifting']:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            if shift is None:
                shift = self.config['shift']
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        if self.config['final_sigmas_type'] == "sigma_min":
            sigma_last = self.sigma_min
        elif self.config['final_sigmas_type'] == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
            )

        timesteps = sigmas * self.config['num_train_timesteps']
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = mx.array(sigmas)
        self.timesteps = mx.array(timesteps, dtype=mx.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [None] * self.config['solver_order']
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        # reset the per-run index counters
        self._step_index = None
        self._begin_index = None

    def _threshold_sample(self, sample: mx.array) -> mx.array:
        """Dynamic thresholding method."""
        dtype = sample.dtype
        orig_shape = sample.shape
        batch_size = orig_shape[0]

        # Flatten everything but the batch dimension for the quantile
        sample_flat = sample.reshape(batch_size, -1)
        abs_sample = mx.abs(sample_flat)

        # MLX has no built-in quantile op, so take the empirical quantile
        # of each row by sorting and indexing at the requested ratio
        sorted_abs = mx.sort(abs_sample, axis=1)
        n = sorted_abs.shape[1]
        idx = min(int(self.config['dynamic_thresholding_ratio'] * (n - 1)), n - 1)
        s = sorted_abs[:, idx:idx + 1]

        # Clamp to [1, sample_max_value] so well-behaved samples pass through
        s = mx.clip(s, 1, self.config['sample_max_value'])

        # Threshold and normalize
        sample_flat = mx.clip(sample_flat, -s, s) / s
        sample = sample_flat.reshape(orig_shape)

        return sample.astype(dtype)

    def _sigma_to_t(self, sigma):
        return sigma * self.config['num_train_timesteps']

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    def time_shift(self, mu: float, sigma: float, t):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
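    # Under flow matching the network predicts a velocity v, and the noisy
    # sample is x_t = (1 - sigma_t) * x0 + sigma_t * noise, so the data
    # prediction is recovered as x0 = x_t - sigma_t * v. `predict_x0`
    # selects whether UniPC runs in data- or noise-prediction space.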
    def convert_model_output(
        self,
        model_output: mx.array,
        sample: mx.array = None,
        **kwargs,
    ) -> mx.array:
        """Convert the model output to the corresponding type the UniPC algorithm needs."""
        sigma_t = self.sigmas[self.step_index]

        if self.predict_x0:
            if self.config['prediction_type'] == "flow_prediction":
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
                    f"for the UniPCMultistepScheduler."
                )

            if self.config['thresholding']:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config['prediction_type'] == "flow_prediction":
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
                    f"for the UniPCMultistepScheduler."
                )

            if self.config['thresholding']:
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon

    def multistep_uni_p_bh_update(
        self,
        model_output: mx.array,
        sample: mx.array = None,
        order: int = None,
        **kwargs,
    ) -> mx.array:
        """One step for the UniP (B(h) version)."""
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)

        h = lambda_t - lambda_s0

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(mx.array(1.0))
        rks = mx.stack(rks)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = mx.exp(hh) - 1  # h * phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config['solver_type'] == "bh1":
            B_h = hh
        elif self.config['solver_type'] == "bh2":
            B_h = mx.exp(hh) - 1
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(mx.power(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = mx.stack(R)
        b = mx.stack(b)

        if len(D1s) > 0:
            D1s = mx.stack(D1s, axis=0)  # (K, B, C, ...)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = mx.array([0.5], dtype=x.dtype)
            else:
                # solve expects a matrix RHS; reshape, solve, then flatten
                rhos_p = mx.linalg.solve(
                    R[:-1, :-1], b[:-1].reshape(-1, 1), stream=mx.cpu
                ).reshape(-1).astype(x.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                # Broadcast the solver weights over the stacked differences
                rhos = rhos_p.reshape((-1,) + (1,) * (D1s.ndim - 1))
                pred_res = mx.sum(rhos * D1s, axis=0)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                rhos = rhos_p.reshape((-1,) + (1,) * (D1s.ndim - 1))
                pred_res = mx.sum(rhos * D1s, axis=0)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.astype(x.dtype)
        return x_t
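    # After the model has been evaluated at the point proposed by UniP,
    # UniC reuses that fresh evaluation to correct the previous update,
    # raising the effective order of accuracy at no extra model cost.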
    def multistep_uni_c_bh_update(
        self,
        this_model_output: mx.array,
        last_sample: mx.array = None,
        this_sample: mx.array = None,
        order: int = None,
        **kwargs,
    ) -> mx.array:
        """One step for the UniC (B(h) version)."""
        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)

        h = lambda_t - lambda_s0

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(mx.array(1.0))
        rks = mx.stack(rks)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = mx.exp(hh) - 1  # h * phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config['solver_type'] == "bh1":
            B_h = hh
        elif self.config['solver_type'] == "bh2":
            B_h = mx.exp(hh) - 1
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(mx.power(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = mx.stack(R)
        b = mx.stack(b)

        if len(D1s) > 0:
            D1s = mx.stack(D1s, axis=0)  # (K, B, C, ...)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = mx.array([0.5], dtype=x.dtype)
        else:
            # solve expects a matrix RHS; reshape, solve, then flatten
            rhos_c = mx.linalg.solve(
                R, b.reshape(-1, 1), stream=mx.cpu
            ).reshape(-1).astype(x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                # Broadcast the solver weights over the stacked differences
                rhos = rhos_c[:-1].reshape((-1,) + (1,) * (D1s.ndim - 1))
                corr_res = mx.sum(rhos * D1s, axis=0)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                rhos = rhos_c[:-1].reshape((-1,) + (1,) * (D1s.ndim - 1))
                corr_res = mx.sum(rhos * D1s, axis=0)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)

        x_t = x_t.astype(x.dtype)
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        # argmax over the boolean mask returns the first matching index
        condition = schedule_timesteps == timestep
        index = mx.argmax(condition.astype(mx.int32))
        return int(index)

    def _init_step_index(self, timestep):
        """Initialize the step_index counter for the scheduler."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
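    # One scheduler step applies the UniC corrector to the previous
    # prediction (reusing the model output just computed there) and then
    # the UniP predictor toward the next sigma. The history kept in
    # `self.model_outputs` supplies the multistep differences D1s.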
    def step(
        self,
        model_output: mx.array,
        timestep: Union[int, mx.array],
        sample: mx.array,
        return_dict: bool = True,
        generator=None,
    ) -> Union[SchedulerOutput, Tuple]:
        """Predict the sample from the previous timestep."""
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        use_corrector = (
            self.step_index > 0
            and self.step_index - 1 not in self.disable_corrector
            and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)

        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Shift the model-output and timestep histories by one slot
        for i in range(self.config['solver_order'] - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.config['lower_order_final']:
            this_order = min(
                self.config['solver_order'],
                len(self.timesteps) - self.step_index,
            )
        else:
            this_order = self.config['solver_order']

        # warm up: the usable order is capped by the history collected so far
        self.this_order = min(this_order, self.lower_order_nums + 1)
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config['solver_order']:
            self.lower_order_nums += 1

        # Increase step index
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
        """Scale model input - no scaling is needed for this scheduler."""
        return sample

    def add_noise(
        self,
        original_samples: mx.array,
        noise: mx.array,
        timesteps: mx.array,
    ) -> mx.array:
        """Add noise to original samples."""
        sigmas = self.sigmas.astype(original_samples.dtype)
        schedule_timesteps = self.timesteps

        # Get step indices
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps) for t in timesteps
            ]
        elif self.step_index is not None:
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[mx.array(step_indices)]
        while len(sigma.shape) < len(original_samples.shape):
            sigma = mx.expand_dims(sigma, -1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise

        return noisy_samples

    def __len__(self):
        return self.config['num_train_timesteps']
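

if __name__ == "__main__":
    # Minimal usage sketch: drives a full sampling loop with a stand-in
    # "model" that returns random velocities. The latent shape and shift
    # value below are illustrative, not tied to any particular network.
    scheduler = FlowUniPCMultistepScheduler(shift=5.0)
    scheduler.set_timesteps(num_inference_steps=10)

    sample = mx.random.normal((1, 16, 32, 32))  # hypothetical latent shape
    for i in range(scheduler.timesteps.shape[0]):
        t = scheduler.timesteps[i]
        velocity = mx.random.normal(sample.shape)  # a real model call goes here
        sample = scheduler.step(velocity, t, sample).prev_sample

    print(sample.shape)  # (1, 16, 32, 32)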