mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
Fix image2image for SDXL (#563)
--------- Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
This commit is contained in:
parent
d0fa6cfcae
commit
fe5edee360
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from stable_diffusion import StableDiffusion
|
from stable_diffusion import StableDiffusion, StableDiffusionXL
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -15,10 +15,11 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument("image")
|
parser.add_argument("image")
|
||||||
parser.add_argument("prompt")
|
parser.add_argument("prompt")
|
||||||
|
parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
|
||||||
parser.add_argument("--strength", type=float, default=0.9)
|
parser.add_argument("--strength", type=float, default=0.9)
|
||||||
parser.add_argument("--n_images", type=int, default=4)
|
parser.add_argument("--n_images", type=int, default=4)
|
||||||
parser.add_argument("--steps", type=int, default=50)
|
parser.add_argument("--steps", type=int)
|
||||||
parser.add_argument("--cfg", type=float, default=7.5)
|
parser.add_argument("--cfg", type=float)
|
||||||
parser.add_argument("--negative_prompt", default="")
|
parser.add_argument("--negative_prompt", default="")
|
||||||
parser.add_argument("--n_rows", type=int, default=1)
|
parser.add_argument("--n_rows", type=int, default=1)
|
||||||
parser.add_argument("--decoding_batch_size", type=int, default=1)
|
parser.add_argument("--decoding_batch_size", type=int, default=1)
|
||||||
@ -29,10 +30,23 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--verbose", "-v", action="store_true")
|
parser.add_argument("--verbose", "-v", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
|
if args.model == "sdxl":
|
||||||
|
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||||
|
if args.quantize:
|
||||||
|
QuantizedLinear.quantize_module(sd.text_encoder_1)
|
||||||
|
QuantizedLinear.quantize_module(sd.text_encoder_2)
|
||||||
|
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||||
|
args.cfg = args.cfg or 0.0
|
||||||
|
args.steps = args.steps or 2
|
||||||
|
else:
|
||||||
|
sd = StableDiffusion(
|
||||||
|
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
|
||||||
|
)
|
||||||
if args.quantize:
|
if args.quantize:
|
||||||
QuantizedLinear.quantize_module(sd.text_encoder)
|
QuantizedLinear.quantize_module(sd.text_encoder)
|
||||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||||
|
args.cfg = args.cfg or 7.5
|
||||||
|
args.steps = args.steps or 50
|
||||||
if args.preload_models:
|
if args.preload_models:
|
||||||
sd.ensure_models_are_loaded()
|
sd.ensure_models_are_loaded()
|
||||||
|
|
||||||
@ -64,6 +78,10 @@ if __name__ == "__main__":
|
|||||||
# The following is not necessary but it may help in memory
|
# The following is not necessary but it may help in memory
|
||||||
# constrained systems by reusing the memory kept by the unet and the text
|
# constrained systems by reusing the memory kept by the unet and the text
|
||||||
# encoders.
|
# encoders.
|
||||||
|
if args.model == "sdxl":
|
||||||
|
del sd.text_encoder_1
|
||||||
|
del sd.text_encoder_2
|
||||||
|
else:
|
||||||
del sd.text_encoder
|
del sd.text_encoder
|
||||||
del sd.unet
|
del sd.unet
|
||||||
del sd.sampler
|
del sd.sampler
|
||||||
|
@ -264,3 +264,42 @@ class StableDiffusionXL(StableDiffusion):
|
|||||||
cfg_weight,
|
cfg_weight,
|
||||||
text_time=text_time,
|
text_time=text_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def generate_latents_from_image(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
text: str,
|
||||||
|
n_images: int = 1,
|
||||||
|
strength: float = 0.8,
|
||||||
|
num_steps: int = 2,
|
||||||
|
cfg_weight: float = 0.0,
|
||||||
|
negative_text: str = "",
|
||||||
|
seed=None,
|
||||||
|
):
|
||||||
|
# Set the PRNG state
|
||||||
|
seed = seed or int(time.time())
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
# Define the num steps and start step
|
||||||
|
start_step = self.sampler.max_time * strength
|
||||||
|
num_steps = int(num_steps * strength)
|
||||||
|
|
||||||
|
# Get the text conditioning
|
||||||
|
conditioning, pooled_conditioning = self._get_text_conditioning(
|
||||||
|
text, n_images, cfg_weight, negative_text
|
||||||
|
)
|
||||||
|
text_time = (
|
||||||
|
pooled_conditioning,
|
||||||
|
mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the latents from the input image and add noise according to the
|
||||||
|
# start time.
|
||||||
|
x_0, _ = self.autoencoder.encode(image[None])
|
||||||
|
x_0 = mx.broadcast_to(x_0, (n_images,) + x_0.shape[1:])
|
||||||
|
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
|
||||||
|
|
||||||
|
# Perform the denoising loop
|
||||||
|
yield from self._denoising_loop(
|
||||||
|
x_T, start_step, conditioning, num_steps, cfg_weight, text_time=text_time
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user