mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Fix image2image for SDXL (#563)
--------- Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from stable_diffusion import StableDiffusion
|
||||
from stable_diffusion import StableDiffusion, StableDiffusionXL
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -15,10 +15,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("image")
|
||||
parser.add_argument("prompt")
|
||||
parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
|
||||
parser.add_argument("--strength", type=float, default=0.9)
|
||||
parser.add_argument("--n_images", type=int, default=4)
|
||||
parser.add_argument("--steps", type=int, default=50)
|
||||
parser.add_argument("--cfg", type=float, default=7.5)
|
||||
parser.add_argument("--steps", type=int)
|
||||
parser.add_argument("--cfg", type=float)
|
||||
parser.add_argument("--negative_prompt", default="")
|
||||
parser.add_argument("--n_rows", type=int, default=1)
|
||||
parser.add_argument("--decoding_batch_size", type=int, default=1)
|
||||
@@ -29,10 +30,23 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder_1)
|
||||
QuantizedLinear.quantize_module(sd.text_encoder_2)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 0.0
|
||||
args.steps = args.steps or 2
|
||||
else:
|
||||
sd = StableDiffusion(
|
||||
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
|
||||
)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
|
||||
@@ -64,7 +78,11 @@ if __name__ == "__main__":
|
||||
# The following is not necessary but it may help in memory
|
||||
# constrained systems by reusing the memory kept by the unet and the text
|
||||
# encoders.
|
||||
del sd.text_encoder
|
||||
if args.model == "sdxl":
|
||||
del sd.text_encoder_1
|
||||
del sd.text_encoder_2
|
||||
else:
|
||||
del sd.text_encoder
|
||||
del sd.unet
|
||||
del sd.sampler
|
||||
peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
|
||||
|
Reference in New Issue
Block a user