From 3f3741d229f9dbefdfefa3082fbac444248f86be Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 14 Mar 2024 12:22:54 -0700
Subject: [PATCH] Fix requirements and image2image strength/steps mismatch (#585)

---
 stable_diffusion/image2image.py   | 14 +++++++++++++-
 stable_diffusion/requirements.txt |  2 +-
 stable_diffusion/txt2image.py     |  3 +++
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py
index 03b41332..802dee57 100644
--- a/stable_diffusion/image2image.py
+++ b/stable_diffusion/image2image.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import math
 
 import mlx.core as mx
 import numpy as np
@@ -30,6 +31,7 @@ if __name__ == "__main__":
     parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
+    # Load the models
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
@@ -47,6 +49,16 @@ if __name__ == "__main__":
             QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50
+
+    # Fix the steps if they were set too low
+    if int(args.steps * args.strength) < 1:
+        args.steps = int(math.ceil(1 / args.strength))
+        if args.verbose:
+            print(
+                f"Strength {args.strength} is too low so steps were set to {args.steps}"
+            )
+
+    # Ensure that models are read in memory if needed
     if args.preload_models:
         sd.ensure_models_are_loaded()
 
@@ -62,7 +74,7 @@ if __name__ == "__main__":
     img = mx.array(np.array(img))
     img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
 
-    # Noise and denoise the latents produced by encoding img.
+    # Noise and denoise the latents produced by encoding the img.
     latents = sd.generate_latents_from_image(
         img,
         args.prompt,
diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt
index 18291154..ab85c726 100644
--- a/stable_diffusion/requirements.txt
+++ b/stable_diffusion/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.6
 huggingface-hub
 regex
 numpy
diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py
index 0640f71f..1566bf6b 100644
--- a/stable_diffusion/txt2image.py
+++ b/stable_diffusion/txt2image.py
@@ -30,6 +30,7 @@ if __name__ == "__main__":
     parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
+    # Load the models
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
@@ -47,6 +48,8 @@ if __name__ == "__main__":
             QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50
+
+    # Ensure that models are read in memory if needed
     if args.preload_models:
         sd.ensure_models_are_loaded()
 