Fix image2image for SDXL (#563)

---------

Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
This commit is contained in:
devonthomas35
2024-03-11 12:18:47 -07:00
committed by GitHub
parent d0fa6cfcae
commit fe5edee360
2 changed files with 65 additions and 8 deletions

View File

@@ -264,3 +264,42 @@ class StableDiffusionXL(StableDiffusion):
cfg_weight,
text_time=text_time,
)
def generate_latents_from_image(
self,
image,
text: str,
n_images: int = 1,
strength: float = 0.8,
num_steps: int = 2,
cfg_weight: float = 0.0,
negative_text: str = "",
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Define the num steps and start step
start_step = self.sampler.max_time * strength
num_steps = int(num_steps * strength)
# Get the text conditioning
conditioning, pooled_conditioning = self._get_text_conditioning(
text, n_images, cfg_weight, negative_text
)
text_time = (
pooled_conditioning,
mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
)
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, (n_images,) + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop
yield from self._denoising_loop(
x_T, start_step, conditioning, num_steps, cfg_weight, text_time=text_time
)