diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py index 03b41332..802dee57 100644 --- a/stable_diffusion/image2image.py +++ b/stable_diffusion/image2image.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import argparse +import math import mlx.core as mx import numpy as np @@ -30,6 +31,7 @@ if __name__ == "__main__": parser.add_argument("--verbose", "-v", action="store_true") args = parser.parse_args() + # Load the models if args.model == "sdxl": sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) if args.quantize: @@ -47,6 +49,16 @@ if __name__ == "__main__": QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8) args.cfg = args.cfg or 7.5 args.steps = args.steps or 50 + + # Fix the steps if they were set too low + if int(args.steps * args.strength) < 1: + args.steps = int(math.ceil(1 / args.strength)) + if args.verbose: + print( + f"Strength {args.strength} is too low so steps were set to {args.steps}" + ) + + # Ensure that models are read in memory if needed if args.preload_models: sd.ensure_models_are_loaded() @@ -62,7 +74,7 @@ if __name__ == "__main__": img = mx.array(np.array(img)) img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1 - # Noise and denoise the latents produced by encoding img. + # Noise and denoise the latents produced by encoding the img. latents = sd.generate_latents_from_image( img, args.prompt, diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt index 18291154..ab85c726 100644 --- a/stable_diffusion/requirements.txt +++ b/stable_diffusion/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.1 +mlx>=0.6 huggingface-hub regex numpy diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py index 0640f71f..1566bf6b 100644 --- a/stable_diffusion/txt2image.py +++ b/stable_diffusion/txt2image.py @@ -30,6 +30,7 @@ if __name__ == "__main__": parser.add_argument("--verbose", "-v", action="store_true") args = parser.parse_args() + # Load the models if args.model == "sdxl": sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) if args.quantize: @@ -47,6 +48,8 @@ if __name__ == "__main__": QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8) args.cfg = args.cfg or 7.5 args.steps = args.steps or 50 + + # Ensure that models are read in memory if needed if args.preload_models: sd.ensure_models_are_loaded()