Fix requirements and image2image strength/steps mismatch (#585)

This commit is contained in:
Angelos Katharopoulos 2024-03-14 12:22:54 -07:00 committed by GitHub
parent e2205beb66
commit 3f3741d229
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import argparse
+import math
import mlx.core as mx
import numpy as np
@@ -30,6 +31,7 @@ if __name__ == "__main__":
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
# Load the models
if args.model == "sdxl":
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
if args.quantize:
@@ -47,6 +49,16 @@ if __name__ == "__main__":
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 7.5
args.steps = args.steps or 50
+# Fix the steps if they were set too low
+if int(args.steps * args.strength) < 1:
+args.steps = int(math.ceil(1 / args.strength))
+if args.verbose:
+print(
+f"Strength {args.strength} is too low so steps were set to {args.steps}"
+)
# Ensure that models are read in memory if needed
if args.preload_models:
sd.ensure_models_are_loaded()
@@ -62,7 +74,7 @@ if __name__ == "__main__":
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
-# Noise and denoise the latents produced by encoding img.
+# Noise and denoise the latents produced by encoding the img.
latents = sd.generate_latents_from_image(
img,
args.prompt,

View File

@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.6
huggingface-hub
regex
numpy

View File

@@ -30,6 +30,7 @@ if __name__ == "__main__":
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
# Load the models
if args.model == "sdxl":
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
if args.quantize:
@@ -47,6 +48,8 @@ if __name__ == "__main__":
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 7.5
args.steps = args.steps or 50
# Ensure that models are read in memory if needed
if args.preload_models:
sd.ensure_models_are_loaded()