mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 22:58:08 +08:00
Fix image2image for SDXL (#563)
--------- Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
This commit is contained in:
@@ -264,3 +264,42 @@ class StableDiffusionXL(StableDiffusion):
|
||||
cfg_weight,
|
||||
text_time=text_time,
|
||||
)
|
||||
|
||||
def generate_latents_from_image(
|
||||
self,
|
||||
image,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
strength: float = 0.8,
|
||||
num_steps: int = 2,
|
||||
cfg_weight: float = 0.0,
|
||||
negative_text: str = "",
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
seed = seed or int(time.time())
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Define the num steps and start step
|
||||
start_step = self.sampler.max_time * strength
|
||||
num_steps = int(num_steps * strength)
|
||||
|
||||
# Get the text conditioning
|
||||
conditioning, pooled_conditioning = self._get_text_conditioning(
|
||||
text, n_images, cfg_weight, negative_text
|
||||
)
|
||||
text_time = (
|
||||
pooled_conditioning,
|
||||
mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
|
||||
)
|
||||
|
||||
# Get the latents from the input image and add noise according to the
|
||||
# start time.
|
||||
x_0, _ = self.autoencoder.encode(image[None])
|
||||
x_0 = mx.broadcast_to(x_0, (n_images,) + x_0.shape[1:])
|
||||
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
|
||||
|
||||
# Perform the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, start_step, conditioning, num_steps, cfg_weight, text_time=text_time
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user