mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix requirements and image2image strength/steps mismatch (#585)
This commit is contained in:
parent
e2205beb66
commit
3f3741d229
@ -1,6 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@ -30,6 +31,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the models
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
if args.quantize:
|
||||
@ -47,6 +49,16 @@ if __name__ == "__main__":
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
|
||||
# Fix the steps if they were set too low
|
||||
if int(args.steps * args.strength) < 1:
|
||||
args.steps = int(math.ceil(1 / args.strength))
|
||||
if args.verbose:
|
||||
print(
|
||||
f"Strength {args.strength} is too low so steps were set to {args.steps}"
|
||||
)
|
||||
|
||||
# Ensure that models are read in memory if needed
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
|
||||
@ -62,7 +74,7 @@ if __name__ == "__main__":
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
|
||||
|
||||
# Noise and denoise the latents produced by encoding img.
|
||||
# Noise and denoise the latents produced by encoding the img.
|
||||
latents = sd.generate_latents_from_image(
|
||||
img,
|
||||
args.prompt,
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.1
|
||||
mlx>=0.6
|
||||
huggingface-hub
|
||||
regex
|
||||
numpy
|
||||
|
@ -30,6 +30,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the models
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
if args.quantize:
|
||||
@ -47,6 +48,8 @@ if __name__ == "__main__":
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
|
||||
# Ensure that models are read in memory if needed
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user