mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
@@ -79,9 +79,13 @@ class StableDiffusion:
|
||||
|
||||
return x_t_prev
|
||||
|
||||
def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
|
||||
def _denoising_loop(
|
||||
self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5
|
||||
):
|
||||
x_t = x_T
|
||||
for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
|
||||
for t, t_prev in self.sampler.timesteps(
|
||||
num_steps, start_time=T, dtype=self.dtype
|
||||
):
|
||||
x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
|
||||
yield x_t
|
||||
|
||||
@@ -100,7 +104,9 @@ class StableDiffusion:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Get the text conditioning
|
||||
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
|
||||
conditioning = self._get_text_conditioning(
|
||||
text, n_images, cfg_weight, negative_text
|
||||
)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior(
|
||||
@@ -108,7 +114,9 @@ class StableDiffusion:
|
||||
)
|
||||
|
||||
# Perform the denoising loop
|
||||
yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)
|
||||
yield from self._denoising_loop(
|
||||
x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight
|
||||
)
|
||||
|
||||
def generate_latents_from_image(
|
||||
self,
|
||||
@@ -130,16 +138,20 @@ class StableDiffusion:
|
||||
num_steps = int(num_steps * strength)
|
||||
|
||||
# Get the text conditioning
|
||||
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
|
||||
conditioning = self._get_text_conditioning(
|
||||
text, n_images, cfg_weight, negative_text
|
||||
)
|
||||
|
||||
# Get the latents from the input image and add noise according to the
|
||||
# start time.
|
||||
x_0, _ = self.autoencoder.encode(image[None])
|
||||
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
|
||||
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
|
||||
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
|
||||
|
||||
# Perform the denoising loop
|
||||
yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)
|
||||
yield from self._denoising_loop(
|
||||
x_T, start_step, conditioning, num_steps, cfg_weight
|
||||
)
|
||||
|
||||
def decode(self, x_t):
|
||||
x = self.autoencoder.decode(x_t)
|
||||
|
@@ -381,7 +381,6 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
|
||||
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
|
||||
|
||||
# Compute the time embeddings
|
||||
temb = self.timesteps(timestep).astype(x.dtype)
|
||||
temb = self.time_embedding(temb)
|
||||
|
Reference in New Issue
Block a user