Add timeout to generate functions

Add timeout handling to the `generate` entry points across several example scripts. Each one gains an optional timeout (in seconds) and uses the `signal` module (`SIGALRM`) to abort generation that runs too long. A sketch of the shared pattern follows the file list below.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.

* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **stable_diffusion/txt2image.py**
  - Add `--timeout` command-line argument.
  - Implement timeout handling using `signal` module in the `__main__` block.

* **llava/generate.py**
  - Add `--timeout` command-line argument.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **llms/mlx_lm/generate.py**
  - Add `--timeout` command-line argument.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/speculative_decoding/main.py**
  - Add `--timeout` command-line argument.
  - Implement timeout handling using `signal` module in `main` function.

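All of these changes follow the same pattern: register a `SIGALRM` handler, arm the alarm before generation starts, and disarm it in a `finally` block. A minimal sketch of that shared pattern is shown below; the `with_timeout` helper is illustrative only and is not part of this change.

```python
import signal


def with_timeout(fn, timeout=None, *args, **kwargs):
    """Run fn(*args, **kwargs), raising TimeoutError after `timeout` seconds.

    Uses SIGALRM, so it only works on Unix and must run in the main thread.
    """
    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    if timeout:
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)
    try:
        return fn(*args, **kwargs)
    finally:
        if timeout:
            signal.alarm(0)  # disarm the alarm whether generation succeeded or not
```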
---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
Author: 锦此
Date: 2024-10-22 17:06:58 +08:00
Parent: 743763bc2e
Commit: f7bbe458ae
8 changed files with 360 additions and 239 deletions

**cvae/main.py**

```diff
@@ -4,6 +4,7 @@ import argparse
 import time
 from functools import partial
 from pathlib import Path
+import signal
 
 import dataset
 import mlx.core as mx
@@ -67,16 +68,28 @@ def generate(
     model,
     out_file,
     num_samples=128,
+    timeout=None,
 ):
-    # Sample from the latent distribution:
-    z = mx.random.normal([num_samples, model.num_latent_dims])
-
-    # Decode the latent vectors to images:
-    images = model.decode(z)
-
-    # Save all images in a single file
-    grid_image = grid_image_from_batch(images, num_rows=8)
-    grid_image.save(out_file)
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
+
+    try:
+        # Sample from the latent distribution:
+        z = mx.random.normal([num_samples, model.num_latent_dims])
+
+        # Decode the latent vectors to images:
+        images = model.decode(z)
+
+        # Save all images in a single file
+        grid_image = grid_image_from_batch(images, num_rows=8)
+        grid_image.save(out_file)
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 def main(args):
```

**flux/dreambooth.py**

```diff
@@ -4,6 +4,7 @@ import argparse
 import time
 from functools import partial
 from pathlib import Path
+import signal
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -16,31 +17,42 @@ from PIL import Image
 from flux import FluxPipeline, Trainer, load_dataset
 
 
-def generate_progress_images(iteration, flux, args):
+def generate_progress_images(iteration, flux, args, timeout=None):
     """Generate images to monitor the progress of the finetuning."""
-    out_dir = Path(args.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"{iteration:07d}_progress.png"
-    print(f"Generating {str(out_file)}", flush=True)
-
-    # Generate some images and arrange them in a grid
-    n_rows = 2
-    n_images = 4
-    x = flux.generate_images(
-        args.progress_prompt,
-        n_images,
-        args.progress_steps,
-    )
-    x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
-    B, H, W, C = x.shape
-    x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
-    x = x.reshape(n_rows * H, B // n_rows * W, C)
-    x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
-    x = (x * 255).astype(mx.uint8)
-
-    # Save them to disc
-    im = Image.fromarray(np.array(x))
-    im.save(out_file)
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
+
+    try:
+        out_dir = Path(args.output_dir)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        out_file = out_dir / f"{iteration:07d}_progress.png"
+        print(f"Generating {str(out_file)}", flush=True)
+
+        # Generate some images and arrange them in a grid
+        n_rows = 2
+        n_images = 4
+        x = flux.generate_images(
+            args.progress_prompt,
+            n_images,
+            args.progress_steps,
+        )
+        x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
+        B, H, W, C = x.shape
+        x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+        x = x.reshape(n_rows * H, B // n_rows * W, C)
+        x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
+        x = (x * 255).astype(mx.uint8)
+
+        # Save them to disc
+        im = Image.fromarray(np.array(x))
+        im.save(out_file)
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 def save_adapters(iteration, flux, args):
```

**llava/generate.py**

```diff
@@ -3,6 +3,7 @@
 import argparse
 import codecs
 from pathlib import Path
+import signal
 
 import mlx.core as mx
 import requests
@@ -49,6 +50,12 @@ def parse_arguments():
         default=None,
         help="End of sequence token for tokenizer",
     )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=None,
+        help="Timeout in seconds for the generation process.",
+    )
 
     return parser.parse_args()
 
@@ -119,21 +126,32 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
 def main():
     args = parse_arguments()
 
-    tokenizer_config = {}
-    if args.eos_token is not None:
-        tokenizer_config["eos_token"] = args.eos_token
-
-    processor, model = load_model(args.model, tokenizer_config)
-
-    prompt = codecs.decode(args.prompt, "unicode_escape")
-
-    input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
-
-    print(prompt)
-    generated_text = generate_text(
-        input_ids, pixel_values, model, processor, args.max_tokens, args.temp
-    )
-    print(generated_text)
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
+
+    try:
+        tokenizer_config = {}
+        if args.eos_token is not None:
+            tokenizer_config["eos_token"] = args.eos_token
+
+        processor, model = load_model(args.model, tokenizer_config)
+
+        prompt = codecs.decode(args.prompt, "unicode_escape")
+
+        input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
+
+        print(prompt)
+        generated_text = generate_text(
+            input_ids, pixel_values, model, processor, args.max_tokens, args.temp
+        )
+        print(generated_text)
+    finally:
+        if args.timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
```

**llms/gguf_llm/generate.py**

```diff
@@ -2,6 +2,7 @@
 
 import argparse
 import time
+import signal
 
 import mlx.core as mx
 import models
@@ -13,37 +14,49 @@ def generate(
     prompt: str,
     max_tokens: int,
     temp: float = 0.0,
+    timeout: int = None,
 ):
-    prompt = tokenizer.encode(prompt)
-
-    tic = time.time()
-    tokens = []
-    skip = 0
-    for token, n in zip(
-        models.generate(prompt, model, args.temp),
-        range(args.max_tokens),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
-
-        if n == 0:
-            prompt_time = time.time() - tic
-            tic = time.time()
-
-        tokens.append(token.item())
-        s = tokenizer.decode(tokens)
-        print(s[skip:], end="", flush=True)
-        skip = len(s)
-    print(tokenizer.decode(tokens)[skip:], flush=True)
-    gen_time = time.time() - tic
-    print("=" * 10)
-    if len(tokens) == 0:
-        print("No tokens generated for this prompt")
-        return
-    prompt_tps = prompt.size / prompt_time
-    gen_tps = (len(tokens) - 1) / gen_time
-    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
+
+    try:
+        prompt = tokenizer.encode(prompt)
+
+        tic = time.time()
+        tokens = []
+        skip = 0
+        for token, n in zip(
+            models.generate(prompt, model, args.temp),
+            range(args.max_tokens),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
+
+            if n == 0:
+                prompt_time = time.time() - tic
+                tic = time.time()
+
+            tokens.append(token.item())
+            s = tokenizer.decode(tokens)
+            print(s[skip:], end="", flush=True)
+            skip = len(s)
+        print(tokenizer.decode(tokens)[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
@@ -83,4 +96,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     mx.random.seed(args.seed)
     model, tokenizer = models.load(args.gguf, args.repo)
-    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
+    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)
```

**llms/mlx_lm/generate.py**

```diff
@@ -3,6 +3,7 @@
 import argparse
 import json
 import sys
+import signal
 
 import mlx.core as mx
 
@@ -107,6 +108,12 @@ def setup_arg_parser():
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",
     )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=None,
+        help="Timeout in seconds for the generation process.",
+    )
 
     return parser
 
@@ -146,90 +153,101 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Load the prompt cache and metadata if a cache file is provided
-    using_cache = args.prompt_cache_file is not None
-    if using_cache:
-        prompt_cache, metadata = load_prompt_cache(
-            args.prompt_cache_file, return_metadata=True
-        )
-
-    # Building tokenizer_config
-    tokenizer_config = (
-        {} if not using_cache else json.loads(metadata["tokenizer_config"])
-    )
-    if args.trust_remote_code:
-        tokenizer_config["trust_remote_code"] = True
-    if args.eos_token is not None:
-        tokenizer_config["eos_token"] = args.eos_token
-
-    model_path = args.model
-    if using_cache:
-        if model_path is None:
-            model_path = metadata["model"]
-        elif model_path != metadata["model"]:
-            raise ValueError(
-                f"Providing a different model ({model_path}) than that "
-                f"used to create the prompt cache ({metadata['model']}) "
-                "is an error."
-            )
-    model_path = model_path or DEFAULT_MODEL
-
-    model, tokenizer = load(
-        model_path,
-        adapter_path=args.adapter_path,
-        tokenizer_config=tokenizer_config,
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-    elif using_cache:
-        tokenizer.chat_template = metadata["chat_template"]
-
-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
-        messages = [
-            {
-                "role": "user",
-                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
-            }
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        # Treat the prompt as a suffix assuming that the prefix is in the
-        # stored kv cache.
-        if using_cache:
-            test_prompt = tokenizer.apply_chat_template(
-                [{"role": "user", "content": "<query>"}],
-                tokenize=False,
-                add_generation_prompt=True,
-            )
-            prompt = prompt[test_prompt.index("<query>") :]
-    else:
-        prompt = args.prompt
-
-    if args.colorize and not args.verbose:
-        raise ValueError("Cannot use --colorize with --verbose=False")
-    formatter = colorprint_by_t0 if args.colorize else None
-
-    response = generate(
-        model,
-        tokenizer,
-        prompt,
-        args.max_tokens,
-        verbose=args.verbose,
-        formatter=formatter,
-        temp=args.temp,
-        top_p=args.top_p,
-        max_kv_size=args.max_kv_size,
-        prompt_cache=prompt_cache if using_cache else None,
-    )
-    if not args.verbose:
-        print(response)
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
+
+    try:
+        # Load the prompt cache and metadata if a cache file is provided
+        using_cache = args.prompt_cache_file is not None
+        if using_cache:
+            prompt_cache, metadata = load_prompt_cache(
+                args.prompt_cache_file, return_metadata=True
+            )
+
+        # Building tokenizer_config
+        tokenizer_config = (
+            {} if not using_cache else json.loads(metadata["tokenizer_config"])
+        )
+        if args.trust_remote_code:
+            tokenizer_config["trust_remote_code"] = True
+        if args.eos_token is not None:
+            tokenizer_config["eos_token"] = args.eos_token
+
+        model_path = args.model
+        if using_cache:
+            if model_path is None:
+                model_path = metadata["model"]
+            elif model_path != metadata["model"]:
+                raise ValueError(
+                    f"Providing a different model ({model_path}) than that "
+                    f"used to create the prompt cache ({metadata['model']}) "
+                    "is an error."
+                )
+        model_path = model_path or DEFAULT_MODEL
+
+        model, tokenizer = load(
+            model_path,
+            adapter_path=args.adapter_path,
+            tokenizer_config=tokenizer_config,
+        )
+
+        if args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+        elif using_cache:
+            tokenizer.chat_template = metadata["chat_template"]
+
+        if not args.ignore_chat_template and (
+            hasattr(tokenizer, "apply_chat_template")
+            and tokenizer.chat_template is not None
+        ):
+            messages = [
+                {
+                    "role": "user",
+                    "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+                }
+            ]
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+
+            # Treat the prompt as a suffix assuming that the prefix is in the
+            # stored kv cache.
+            if using_cache:
+                test_prompt = tokenizer.apply_chat_template(
+                    [{"role": "user", "content": "<query>"}],
+                    tokenize=False,
+                    add_generation_prompt=True,
+                )
+                prompt = prompt[test_prompt.index("<query>") :]
+        else:
+            prompt = args.prompt
+
+        if args.colorize and not args.verbose:
+            raise ValueError("Cannot use --colorize with --verbose=False")
+        formatter = colorprint_by_t0 if args.colorize else None
+
+        response = generate(
+            model,
+            tokenizer,
+            prompt,
+            args.max_tokens,
+            verbose=args.verbose,
+            formatter=formatter,
+            temp=args.temp,
+            top_p=args.top_p,
+            max_kv_size=args.max_kv_size,
+            prompt_cache=prompt_cache if using_cache else None,
+        )
+        if not args.verbose:
+            print(response)
+    finally:
+        if args.timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
```

**llms/speculative_decoding/main.py**

```diff
@@ -1,5 +1,6 @@
 import argparse
 import time
+import signal
 
 import mlx.core as mx
 from decoder import SpeculativeDecoder
@@ -21,27 +22,38 @@ def load_model(model_name: str):
 def main(args):
     mx.random.seed(args.seed)
 
-    spec_decoder = SpeculativeDecoder(
-        model=load_model(args.model_name),
-        draft_model=load_model(args.draft_model_name),
-        tokenizer=args.model_name,
-        delta=args.delta,
-        num_draft=args.num_draft,
-    )
-
-    tic = time.time()
-    print(args.prompt)
-    if args.regular_decode:
-        spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
-    else:
-        stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
-        print("=" * 10)
-        print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
-        print(f"Decoding steps {stats['n_steps']}.")
-
-    toc = time.time()
-    print("=" * 10)
-    print(f"Full generation time {toc - tic:.3f}")
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
+
+    try:
+        spec_decoder = SpeculativeDecoder(
+            model=load_model(args.model_name),
+            draft_model=load_model(args.draft_model_name),
+            tokenizer=args.model_name,
+            delta=args.delta,
+            num_draft=args.num_draft,
+        )
+
+        tic = time.time()
+        print(args.prompt)
+        if args.regular_decode:
+            spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
+        else:
+            stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
+            print("=" * 10)
+            print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
+            print(f"Decoding steps {stats['n_steps']}.")
+
+        toc = time.time()
+        print("=" * 10)
+        print(f"Full generation time {toc - tic:.3f}")
+    finally:
+        if args.timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
@@ -91,5 +103,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Use regular decoding instead of speculative decoding.",
     )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=None,
+        help="Timeout in seconds for the generation process.",
+    )
     args = parser.parse_args()
     main(args)
```

**musicgen/generate.py**

```diff
@@ -1,16 +1,28 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
+import signal
 
 from utils import save_audio
 
 from musicgen import MusicGen
 
 
-def main(text: str, output_path: str, model_name: str, max_steps: int):
-    model = MusicGen.from_pretrained(model_name)
-    audio = model.generate(text, max_steps=max_steps)
-    save_audio(output_path, audio, model.sampling_rate)
+def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
+
+    try:
+        model = MusicGen.from_pretrained(model_name)
+        audio = model.generate(text, max_steps=max_steps)
+        save_audio(output_path, audio, model.sampling_rate)
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
@@ -19,5 +31,6 @@ if __name__ == "__main__":
     parser.add_argument("--text", required=False, default="happy rock")
     parser.add_argument("--output-path", required=False, default="0.wav")
    parser.add_argument("--max-steps", required=False, default=500, type=int)
+    parser.add_argument("--timeout", required=False, default=None, type=int)
     args = parser.parse_args()
-    main(args.text, args.output_path, args.model, args.max_steps)
+    main(args.text, args.output_path, args.model, args.max_steps, args.timeout)
```
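When the alarm fires, the handler raises `TimeoutError`, which propagates out of `main` after the `finally` block disarms the alarm. A caller that wants to treat a timeout as a soft failure could wrap the call; this is a hypothetical usage sketch, not part of the diff:

```python
try:
    main(args.text, args.output_path, args.model, args.max_steps, args.timeout)
except TimeoutError:
    # The alarm fired before generation completed.
    print(f"Generation exceeded {args.timeout} seconds and was aborted.")
```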

**stable_diffusion/txt2image.py**

```diff
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import signal
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -27,82 +28,97 @@ if __name__ == "__main__":
     parser.add_argument("--preload-models", action="store_true")
     parser.add_argument("--output", default="out.png")
     parser.add_argument("--seed", type=int)
+    parser.add_argument("--timeout", type=int, default=None)
     parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
-    # Load the models
-    if args.model == "sdxl":
-        sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
-        if args.quantize:
-            nn.quantize(
-                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
-            )
-            nn.quantize(
-                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
-            )
-            nn.quantize(sd.unet, group_size=32, bits=8)
-        args.cfg = args.cfg or 0.0
-        args.steps = args.steps or 2
-    else:
-        sd = StableDiffusion(
-            "stabilityai/stable-diffusion-2-1-base", float16=args.float16
-        )
-        if args.quantize:
-            nn.quantize(
-                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
-            )
-            nn.quantize(sd.unet, group_size=32, bits=8)
-        args.cfg = args.cfg or 7.5
-        args.steps = args.steps or 50
-
-    # Ensure that models are read in memory if needed
-    if args.preload_models:
-        sd.ensure_models_are_loaded()
-
-    # Generate the latent vectors using diffusion
-    latents = sd.generate_latents(
-        args.prompt,
-        n_images=args.n_images,
-        cfg_weight=args.cfg,
-        num_steps=args.steps,
-        seed=args.seed,
-        negative_text=args.negative_prompt,
-    )
-    for x_t in tqdm(latents, total=args.steps):
-        mx.eval(x_t)
-
-    # The following is not necessary but it may help in memory
-    # constrained systems by reusing the memory kept by the unet and the text
-    # encoders.
-    if args.model == "sdxl":
-        del sd.text_encoder_1
-        del sd.text_encoder_2
-    else:
-        del sd.text_encoder
-    del sd.unet
-    del sd.sampler
-    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
-
-    # Decode them into images
-    decoded = []
-    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
-        decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
-        mx.eval(decoded[-1])
-    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
-
-    # Arrange them on a grid
-    x = mx.concatenate(decoded, axis=0)
-    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
-    B, H, W, C = x.shape
-    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
-    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
-    x = (x * 255).astype(mx.uint8)
-
-    # Save them to disc
-    im = Image.fromarray(np.array(x))
-    im.save(args.output)
-
-    # Report the peak memory used during generation
-    if args.verbose:
-        print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
-        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
+
+    try:
+        # Load the models
+        if args.model == "sdxl":
+            sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
+            if args.quantize:
+                nn.quantize(
+                    sd.text_encoder_1,
+                    class_predicate=lambda _, m: isinstance(m, nn.Linear),
+                )
+                nn.quantize(
+                    sd.text_encoder_2,
+                    class_predicate=lambda _, m: isinstance(m, nn.Linear),
+                )
+                nn.quantize(sd.unet, group_size=32, bits=8)
+            args.cfg = args.cfg or 0.0
+            args.steps = args.steps or 2
+        else:
+            sd = StableDiffusion(
+                "stabilityai/stable-diffusion-2-1-base", float16=args.float16
+            )
+            if args.quantize:
+                nn.quantize(
+                    sd.text_encoder,
+                    class_predicate=lambda _, m: isinstance(m, nn.Linear),
+                )
+                nn.quantize(sd.unet, group_size=32, bits=8)
+            args.cfg = args.cfg or 7.5
+            args.steps = args.steps or 50
+
+        # Ensure that models are read in memory if needed
+        if args.preload_models:
+            sd.ensure_models_are_loaded()
+
+        # Generate the latent vectors using diffusion
+        latents = sd.generate_latents(
+            args.prompt,
+            n_images=args.n_images,
+            cfg_weight=args.cfg,
+            num_steps=args.steps,
+            seed=args.seed,
+            negative_text=args.negative_prompt,
+        )
+        for x_t in tqdm(latents, total=args.steps):
+            mx.eval(x_t)
+
+        # The following is not necessary but it may help in memory
+        # constrained systems by reusing the memory kept by the unet and the text
+        # encoders.
+        if args.model == "sdxl":
+            del sd.text_encoder_1
+            del sd.text_encoder_2
+        else:
+            del sd.text_encoder
+        del sd.unet
+        del sd.sampler
+        peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
+
+        # Decode them into images
+        decoded = []
+        for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+            decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
+            mx.eval(decoded[-1])
+        peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
+
+        # Arrange them on a grid
+        x = mx.concatenate(decoded, axis=0)
+        x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
+        B, H, W, C = x.shape
+        x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+        x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
+        x = (x * 255).astype(mx.uint8)
+
+        # Save them to disc
+        im = Image.fromarray(np.array(x))
+        im.save(args.output)
+
+        # Report the peak memory used during generation
+        if args.verbose:
+            print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
+            print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
+    finally:
+        if args.timeout:
+            signal.alarm(0)
```