mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Stable diffusion XL (#516)
This commit is contained in:
committed by
GitHub
parent
8c2cf665ed
commit
3a9e6c3f70
@@ -22,10 +22,19 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--negative_prompt", default="")
|
||||
parser.add_argument("--n_rows", type=int, default=1)
|
||||
parser.add_argument("--decoding_batch_size", type=int, default=1)
|
||||
parser.add_argument("--quantize", "-q", action="store_true")
|
||||
parser.add_argument("--no-float16", dest="float16", action="store_false")
|
||||
parser.add_argument("--preload-models", action="store_true")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
sd = StableDiffusion()
|
||||
sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
|
||||
# Read the image
|
||||
img = Image.open(args.image)
|
||||
@@ -52,11 +61,20 @@ if __name__ == "__main__":
|
||||
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
|
||||
mx.eval(x_t)
|
||||
|
||||
# The following is not necessary but it may help in memory
|
||||
# constrained systems by reusing the memory kept by the unet and the text
|
||||
# encoders.
|
||||
del sd.text_encoder
|
||||
del sd.unet
|
||||
del sd.sampler
|
||||
peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
|
||||
mx.eval(decoded[-1])
|
||||
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
|
||||
|
||||
# Arrange them on a grid
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
@@ -69,3 +87,8 @@ if __name__ == "__main__":
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(args.output)
|
||||
|
||||
# Report the peak memory used during generation
|
||||
if args.verbose:
|
||||
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
|
||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
||||
|
||||
Reference in New Issue
Block a user