Support Hugging Face models (#215)

* support hf direct models
This commit is contained in:
Awni Hannun
2024-01-03 15:13:26 -08:00
committed by GitHub
parent 1d09c4fecd
commit a5d6d0436c
16 changed files with 654 additions and 27 deletions

View File

@@ -79,9 +79,13 @@ class StableDiffusion:
return x_t_prev
def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
def _denoising_loop(
self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5
):
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
for t, t_prev in self.sampler.timesteps(
num_steps, start_time=T, dtype=self.dtype
):
x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
yield x_t
@@ -100,7 +104,9 @@ class StableDiffusion:
mx.random.seed(seed)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
conditioning = self._get_text_conditioning(
text, n_images, cfg_weight, negative_text
)
# Create the latent variables
x_T = self.sampler.sample_prior(
@@ -108,7 +114,9 @@ class StableDiffusion:
)
# Perform the denoising loop
yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)
yield from self._denoising_loop(
x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight
)
def generate_latents_from_image(
self,
@@ -130,16 +138,20 @@ class StableDiffusion:
num_steps = int(num_steps * strength)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
conditioning = self._get_text_conditioning(
text, n_images, cfg_weight, negative_text
)
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop
yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)
yield from self._denoising_loop(
x_T, start_step, conditioning, num_steps, cfg_weight
)
def decode(self, x_t):
x = self.autoencoder.decode(x_t)

View File

@@ -381,7 +381,6 @@ class UNetModel(nn.Module):
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)