mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix requirements and image2image strength/steps mismatch (#585)
This commit is contained in:
parent
e2205beb66
commit
3f3741d229
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import math
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -30,6 +31,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--verbose", "-v", action="store_true")
|
parser.add_argument("--verbose", "-v", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load the models
|
||||||
if args.model == "sdxl":
|
if args.model == "sdxl":
|
||||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||||
if args.quantize:
|
if args.quantize:
|
||||||
@ -47,6 +49,16 @@ if __name__ == "__main__":
|
|||||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||||
args.cfg = args.cfg or 7.5
|
args.cfg = args.cfg or 7.5
|
||||||
args.steps = args.steps or 50
|
args.steps = args.steps or 50
|
||||||
|
|
||||||
|
# Fix the steps if they were set too low
|
||||||
|
if int(args.steps * args.strength) < 1:
|
||||||
|
args.steps = int(math.ceil(1 / args.strength))
|
||||||
|
if args.verbose:
|
||||||
|
print(
|
||||||
|
f"Strength {args.strength} is too low so steps were set to {args.steps}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure that models are read in memory if needed
|
||||||
if args.preload_models:
|
if args.preload_models:
|
||||||
sd.ensure_models_are_loaded()
|
sd.ensure_models_are_loaded()
|
||||||
|
|
||||||
@ -62,7 +74,7 @@ if __name__ == "__main__":
|
|||||||
img = mx.array(np.array(img))
|
img = mx.array(np.array(img))
|
||||||
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
|
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
|
||||||
|
|
||||||
# Noise and denoise the latents produced by encoding img.
|
# Noise and denoise the latents produced by encoding the img.
|
||||||
latents = sd.generate_latents_from_image(
|
latents = sd.generate_latents_from_image(
|
||||||
img,
|
img,
|
||||||
args.prompt,
|
args.prompt,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx>=0.1
|
mlx>=0.6
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
regex
|
regex
|
||||||
numpy
|
numpy
|
||||||
|
@ -30,6 +30,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--verbose", "-v", action="store_true")
|
parser.add_argument("--verbose", "-v", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load the models
|
||||||
if args.model == "sdxl":
|
if args.model == "sdxl":
|
||||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||||
if args.quantize:
|
if args.quantize:
|
||||||
@ -47,6 +48,8 @@ if __name__ == "__main__":
|
|||||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||||
args.cfg = args.cfg or 7.5
|
args.cfg = args.cfg or 7.5
|
||||||
args.steps = args.steps or 50
|
args.steps = args.steps or 50
|
||||||
|
|
||||||
|
# Ensure that models are read in memory if needed
|
||||||
if args.preload_models:
|
if args.preload_models:
|
||||||
sd.ensure_models_are_loaded()
|
sd.ensure_models_are_loaded()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user