From f7bbe458ae124af43a4b6fe22cf1680214a9e239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=94=A6=E6=AD=A4?= Date: Tue, 22 Oct 2024 17:06:58 +0800 Subject: [PATCH] Add timeout to generate functions Add optional timeout handling to the generation entry points of several examples. * **cvae/main.py** - Add `timeout` parameter to `generate` function. - Implement timeout handling using `signal` module in `generate` function. * **flux/dreambooth.py** - Add `timeout` parameter to `generate_progress_images` function. - Implement timeout handling using `signal` module in `generate_progress_images` function. * **musicgen/generate.py** - Add `timeout` parameter to `main` function and a `--timeout` command-line argument. - Implement timeout handling using `signal` module in `main` function. * **stable_diffusion/txt2image.py** - Add `--timeout` command-line argument. - Implement timeout handling using `signal` module in the `__main__` script block. * **llava/generate.py** - Add `--timeout` command-line argument. - Implement timeout handling using `signal` module in `main` function. * **llms/gguf_llm/generate.py** - Add `timeout` parameter to `generate` function. - Implement timeout handling using `signal` module in `generate` function. * **llms/mlx_lm/generate.py** - Add `--timeout` command-line argument. - Implement timeout handling using `signal` module in `main` function. * **llms/speculative_decoding/main.py** - Add `--timeout` command-line argument. - Implement timeout handling using `signal` module in `main` function. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX). 
--- cvae/main.py | 27 +++-- flux/dreambooth.py | 56 ++++++---- llava/generate.py | 40 +++++-- llms/gguf_llm/generate.py | 69 +++++++----- llms/mlx_lm/generate.py | 174 ++++++++++++++++-------------- llms/speculative_decoding/main.py | 56 ++++++---- musicgen/generate.py | 23 +++- stable_diffusion/txt2image.py | 154 ++++++++++++++------------ 8 files changed, 360 insertions(+), 239 deletions(-) diff --git a/cvae/main.py b/cvae/main.py index 78ac9b4a..bfc8ed82 100644 --- a/cvae/main.py +++ b/cvae/main.py @@ -4,6 +4,7 @@ import argparse import time from functools import partial from pathlib import Path +import signal import dataset import mlx.core as mx @@ -67,16 +68,28 @@ def generate( model, out_file, num_samples=128, + timeout=None, ): - # Sample from the latent distribution: - z = mx.random.normal([num_samples, model.num_latent_dims]) + def handler(signum, frame): + raise TimeoutError("Generation timed out") - # Decode the latent vectors to images: - images = model.decode(z) + if timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) - # Save all images in a single file - grid_image = grid_image_from_batch(images, num_rows=8) - grid_image.save(out_file) + try: + # Sample from the latent distribution: + z = mx.random.normal([num_samples, model.num_latent_dims]) + + # Decode the latent vectors to images: + images = model.decode(z) + + # Save all images in a single file + grid_image = grid_image_from_batch(images, num_rows=8) + grid_image.save(out_file) + finally: + if timeout: + signal.alarm(0) def main(args): diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 48dcad47..76deea4f 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -4,6 +4,7 @@ import argparse import time from functools import partial from pathlib import Path +import signal import mlx.core as mx import mlx.nn as nn @@ -16,31 +17,42 @@ from PIL import Image from flux import FluxPipeline, Trainer, load_dataset -def generate_progress_images(iteration, flux, args): +def 
generate_progress_images(iteration, flux, args, timeout=None): """Generate images to monitor the progress of the finetuning.""" - out_dir = Path(args.output_dir) - out_dir.mkdir(parents=True, exist_ok=True) - out_file = out_dir / f"{iteration:07d}_progress.png" - print(f"Generating {str(out_file)}", flush=True) + def handler(signum, frame): + raise TimeoutError("Generation timed out") - # Generate some images and arrange them in a grid - n_rows = 2 - n_images = 4 - x = flux.generate_images( - args.progress_prompt, - n_images, - args.progress_steps, - ) - x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) - B, H, W, C = x.shape - x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4) - x = x.reshape(n_rows * H, B // n_rows * W, C) - x = mx.pad(x, [(4, 4), (4, 4), (0, 0)]) - x = (x * 255).astype(mx.uint8) + if timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) - # Save them to disc - im = Image.fromarray(np.array(x)) - im.save(out_file) + try: + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_file = out_dir / f"{iteration:07d}_progress.png" + print(f"Generating {str(out_file)}", flush=True) + + # Generate some images and arrange them in a grid + n_rows = 2 + n_images = 4 + x = flux.generate_images( + args.progress_prompt, + n_images, + args.progress_steps, + ) + x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) + B, H, W, C = x.shape + x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(n_rows * H, B // n_rows * W, C) + x = mx.pad(x, [(4, 4), (4, 4), (0, 0)]) + x = (x * 255).astype(mx.uint8) + + # Save them to disc + im = Image.fromarray(np.array(x)) + im.save(out_file) + finally: + if timeout: + signal.alarm(0) def save_adapters(iteration, flux, args): diff --git a/llava/generate.py b/llava/generate.py index 8067839e..84b764d2 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -3,6 +3,7 @@ import argparse import codecs from pathlib import Path +import signal 
import mlx.core as mx import requests @@ -49,6 +50,12 @@ def parse_arguments(): default=None, help="End of sequence token for tokenizer", ) + parser.add_argument( + "--timeout", + type=int, + default=None, + help="Timeout in seconds for the generation process.", + ) return parser.parse_args() @@ -119,21 +126,32 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera def main(): args = parse_arguments() - tokenizer_config = {} - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token + def handler(signum, frame): + raise TimeoutError("Generation timed out") - processor, model = load_model(args.model, tokenizer_config) + if args.timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(args.timeout) - prompt = codecs.decode(args.prompt, "unicode_escape") + try: + tokenizer_config = {} + if args.eos_token is not None: + tokenizer_config["eos_token"] = args.eos_token - input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) + processor, model = load_model(args.model, tokenizer_config) - print(prompt) - generated_text = generate_text( - input_ids, pixel_values, model, processor, args.max_tokens, args.temp - ) - print(generated_text) + prompt = codecs.decode(args.prompt, "unicode_escape") + + input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) + + print(prompt) + generated_text = generate_text( + input_ids, pixel_values, model, processor, args.max_tokens, args.temp + ) + print(generated_text) + finally: + if args.timeout: + signal.alarm(0) if __name__ == "__main__": diff --git a/llms/gguf_llm/generate.py b/llms/gguf_llm/generate.py index 7215aa48..c8696518 100644 --- a/llms/gguf_llm/generate.py +++ b/llms/gguf_llm/generate.py @@ -2,6 +2,7 @@ import argparse import time +import signal import mlx.core as mx import models @@ -13,37 +14,49 @@ def generate( prompt: str, max_tokens: int, temp: float = 0.0, + timeout: int = None, ): - prompt = tokenizer.encode(prompt) + def 
handler(signum, frame): + raise TimeoutError("Generation timed out") - tic = time.time() - tokens = [] - skip = 0 - for token, n in zip( - models.generate(prompt, model, args.temp), - range(args.max_tokens), - ): - if token == tokenizer.eos_token_id: - break + if timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) - if n == 0: - prompt_time = time.time() - tic - tic = time.time() + try: + prompt = tokenizer.encode(prompt) - tokens.append(token.item()) - s = tokenizer.decode(tokens) - print(s[skip:], end="", flush=True) - skip = len(s) - print(tokenizer.decode(tokens)[skip:], flush=True) - gen_time = time.time() - tic - print("=" * 10) - if len(tokens) == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt.size / prompt_time - gen_tps = (len(tokens) - 1) / gen_time - print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {gen_tps:.3f} tokens-per-sec") + tic = time.time() + tokens = [] + skip = 0 + for token, n in zip( + models.generate(prompt, model, args.temp), + range(args.max_tokens), + ): + if token == tokenizer.eos_token_id: + break + + if n == 0: + prompt_time = time.time() - tic + tic = time.time() + + tokens.append(token.item()) + s = tokenizer.decode(tokens) + print(s[skip:], end="", flush=True) + skip = len(s) + print(tokenizer.decode(tokens)[skip:], flush=True) + gen_time = time.time() - tic + print("=" * 10) + if len(tokens) == 0: + print("No tokens generated for this prompt") + return + prompt_tps = prompt.size / prompt_time + gen_tps = (len(tokens) - 1) / gen_time + print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + print(f"Generation: {gen_tps:.3f} tokens-per-sec") + finally: + if timeout: + signal.alarm(0) if __name__ == "__main__": @@ -83,4 +96,4 @@ if __name__ == "__main__": args = parser.parse_args() mx.random.seed(args.seed) model, tokenizer = models.load(args.gguf, args.repo) - generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) + generate(model, tokenizer, 
args.prompt, args.max_tokens, args.temp, timeout=args.timeout) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0bf98ab2..253db12c 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -3,6 +3,7 @@ import argparse import json import sys +import signal import mlx.core as mx @@ -107,6 +108,12 @@ def setup_arg_parser(): default=None, help="A file containing saved KV caches to avoid recomputing them", ) + parser.add_argument( + "--timeout", + type=int, + default=None, + help="Timeout in seconds for the generation process.", + ) return parser @@ -146,90 +153,101 @@ def main(): if args.cache_limit_gb is not None: mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the prompt cache and metadata if a cache file is provided - using_cache = args.prompt_cache_file is not None - if using_cache: - prompt_cache, metadata = load_prompt_cache( - args.prompt_cache_file, return_metadata=True - ) + def handler(signum, frame): + raise TimeoutError("Generation timed out") - # Building tokenizer_config - tokenizer_config = ( - {} if not using_cache else json.loads(metadata["tokenizer_config"]) - ) - if args.trust_remote_code: - tokenizer_config["trust_remote_code"] = True - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token + if args.timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(args.timeout) - model_path = args.model - if using_cache: - if model_path is None: - model_path = metadata["model"] - elif model_path != metadata["model"]: - raise ValueError( - f"Providing a different model ({model_path}) than that " - f"used to create the prompt cache ({metadata['model']}) " - "is an error." 
- ) - model_path = model_path or DEFAULT_MODEL - - model, tokenizer = load( - model_path, - adapter_path=args.adapter_path, - tokenizer_config=tokenizer_config, - ) - - if args.use_default_chat_template: - if tokenizer.chat_template is None: - tokenizer.chat_template = tokenizer.default_chat_template - elif using_cache: - tokenizer.chat_template = metadata["chat_template"] - - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): - messages = [ - { - "role": "user", - "content": sys.stdin.read() if args.prompt == "-" else args.prompt, - } - ] - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - # Treat the prompt as a suffix assuming that the prefix is in the - # stored kv cache. + try: + # Load the prompt cache and metadata if a cache file is provided + using_cache = args.prompt_cache_file is not None if using_cache: - test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], - tokenize=False, - add_generation_prompt=True, + prompt_cache, metadata = load_prompt_cache( + args.prompt_cache_file, return_metadata=True ) - prompt = prompt[test_prompt.index("") :] - else: - prompt = args.prompt - if args.colorize and not args.verbose: - raise ValueError("Cannot use --colorize with --verbose=False") - formatter = colorprint_by_t0 if args.colorize else None + # Building tokenizer_config + tokenizer_config = ( + {} if not using_cache else json.loads(metadata["tokenizer_config"]) + ) + if args.trust_remote_code: + tokenizer_config["trust_remote_code"] = True + if args.eos_token is not None: + tokenizer_config["eos_token"] = args.eos_token - response = generate( - model, - tokenizer, - prompt, - args.max_tokens, - verbose=args.verbose, - formatter=formatter, - temp=args.temp, - top_p=args.top_p, - max_kv_size=args.max_kv_size, - prompt_cache=prompt_cache if using_cache else None, - ) - if not args.verbose: - print(response) + 
model_path = args.model + if using_cache: + if model_path is None: + model_path = metadata["model"] + elif model_path != metadata["model"]: + raise ValueError( + f"Providing a different model ({model_path}) than that " + f"used to create the prompt cache ({metadata['model']}) " + "is an error." + ) + model_path = model_path or DEFAULT_MODEL + + model, tokenizer = load( + model_path, + adapter_path=args.adapter_path, + tokenizer_config=tokenizer_config, + ) + + if args.use_default_chat_template: + if tokenizer.chat_template is None: + tokenizer.chat_template = tokenizer.default_chat_template + elif using_cache: + tokenizer.chat_template = metadata["chat_template"] + + if not args.ignore_chat_template and ( + hasattr(tokenizer, "apply_chat_template") + and tokenizer.chat_template is not None + ): + messages = [ + { + "role": "user", + "content": sys.stdin.read() if args.prompt == "-" else args.prompt, + } + ] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Treat the prompt as a suffix assuming that the prefix is in the + # stored kv cache. 
+ if using_cache: + test_prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": ""}], + tokenize=False, + add_generation_prompt=True, + ) + prompt = prompt[test_prompt.index("") :] + else: + prompt = args.prompt + + if args.colorize and not args.verbose: + raise ValueError("Cannot use --colorize with --verbose=False") + formatter = colorprint_by_t0 if args.colorize else None + + response = generate( + model, + tokenizer, + prompt, + args.max_tokens, + verbose=args.verbose, + formatter=formatter, + temp=args.temp, + top_p=args.top_p, + max_kv_size=args.max_kv_size, + prompt_cache=prompt_cache if using_cache else None, + ) + if not args.verbose: + print(response) + finally: + if args.timeout: + signal.alarm(0) if __name__ == "__main__": diff --git a/llms/speculative_decoding/main.py b/llms/speculative_decoding/main.py index b1da3a5e..fefe46b7 100644 --- a/llms/speculative_decoding/main.py +++ b/llms/speculative_decoding/main.py @@ -1,5 +1,6 @@ import argparse import time +import signal import mlx.core as mx from decoder import SpeculativeDecoder @@ -21,27 +22,38 @@ def load_model(model_name: str): def main(args): mx.random.seed(args.seed) - spec_decoder = SpeculativeDecoder( - model=load_model(args.model_name), - draft_model=load_model(args.draft_model_name), - tokenizer=args.model_name, - delta=args.delta, - num_draft=args.num_draft, - ) + def handler(signum, frame): + raise TimeoutError("Generation timed out") - tic = time.time() - print(args.prompt) - if args.regular_decode: - spec_decoder.generate(args.prompt, max_tokens=args.max_tokens) - else: - stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens) + if args.timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(args.timeout) + + try: + spec_decoder = SpeculativeDecoder( + model=load_model(args.model_name), + draft_model=load_model(args.draft_model_name), + tokenizer=args.model_name, + delta=args.delta, + num_draft=args.num_draft, + ) + + tic = time.time() + 
print(args.prompt) + if args.regular_decode: + spec_decoder.generate(args.prompt, max_tokens=args.max_tokens) + else: + stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens) + print("=" * 10) + print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.") + print(f"Decoding steps {stats['n_steps']}.") + + toc = time.time() print("=" * 10) - print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.") - print(f"Decoding steps {stats['n_steps']}.") - - toc = time.time() - print("=" * 10) - print(f"Full generation time {toc - tic:.3f}") + print(f"Full generation time {toc - tic:.3f}") + finally: + if args.timeout: + signal.alarm(0) if __name__ == "__main__": @@ -91,5 +103,11 @@ if __name__ == "__main__": action="store_true", help="Use regular decoding instead of speculative decoding.", ) + parser.add_argument( + "--timeout", + type=int, + default=None, + help="Timeout in seconds for the generation process.", + ) args = parser.parse_args() main(args) diff --git a/musicgen/generate.py b/musicgen/generate.py index 5a6b7804..7b4fb3c2 100644 --- a/musicgen/generate.py +++ b/musicgen/generate.py @@ -1,16 +1,28 @@ # Copyright © 2024 Apple Inc. 
import argparse +import signal from utils import save_audio from musicgen import MusicGen -def main(text: str, output_path: str, model_name: str, max_steps: int): - model = MusicGen.from_pretrained(model_name) - audio = model.generate(text, max_steps=max_steps) - save_audio(output_path, audio, model.sampling_rate) +def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None): + def handler(signum, frame): + raise TimeoutError("Generation timed out") + + if timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) + + try: + model = MusicGen.from_pretrained(model_name) + audio = model.generate(text, max_steps=max_steps) + save_audio(output_path, audio, model.sampling_rate) + finally: + if timeout: + signal.alarm(0) if __name__ == "__main__": @@ -19,5 +31,6 @@ if __name__ == "__main__": parser.add_argument("--text", required=False, default="happy rock") parser.add_argument("--output-path", required=False, default="0.wav") parser.add_argument("--max-steps", required=False, default=500, type=int) + parser.add_argument("--timeout", required=False, default=None, type=int) args = parser.parse_args() - main(args.text, args.output_path, args.model, args.max_steps) + main(args.text, args.output_path, args.model, args.max_steps, args.timeout) diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py index 26c757f8..c516b1cc 100644 --- a/stable_diffusion/txt2image.py +++ b/stable_diffusion/txt2image.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. 
import argparse +import signal import mlx.core as mx import mlx.nn as nn @@ -27,82 +28,97 @@ if __name__ == "__main__": parser.add_argument("--preload-models", action="store_true") parser.add_argument("--output", default="out.png") parser.add_argument("--seed", type=int) + parser.add_argument("--timeout", type=int, default=None) parser.add_argument("--verbose", "-v", action="store_true") args = parser.parse_args() - # Load the models - if args.model == "sdxl": - sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) - if args.quantize: - nn.quantize( - sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear) + def handler(signum, frame): + raise TimeoutError("Generation timed out") + + if args.timeout: + signal.signal(signal.SIGALRM, handler) + signal.alarm(args.timeout) + + try: + # Load the models + if args.model == "sdxl": + sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) + if args.quantize: + nn.quantize( + sd.text_encoder_1, + class_predicate=lambda _, m: isinstance(m, nn.Linear), + ) + nn.quantize( + sd.text_encoder_2, + class_predicate=lambda _, m: isinstance(m, nn.Linear), + ) + nn.quantize(sd.unet, group_size=32, bits=8) + args.cfg = args.cfg or 0.0 + args.steps = args.steps or 2 + else: + sd = StableDiffusion( + "stabilityai/stable-diffusion-2-1-base", float16=args.float16 ) - nn.quantize( - sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear) - ) - nn.quantize(sd.unet, group_size=32, bits=8) - args.cfg = args.cfg or 0.0 - args.steps = args.steps or 2 - else: - sd = StableDiffusion( - "stabilityai/stable-diffusion-2-1-base", float16=args.float16 + if args.quantize: + nn.quantize( + sd.text_encoder, + class_predicate=lambda _, m: isinstance(m, nn.Linear), + ) + nn.quantize(sd.unet, group_size=32, bits=8) + args.cfg = args.cfg or 7.5 + args.steps = args.steps or 50 + + # Ensure that models are read in memory if needed + if args.preload_models: + sd.ensure_models_are_loaded() + + # 
Generate the latent vectors using diffusion + latents = sd.generate_latents( + args.prompt, + n_images=args.n_images, + cfg_weight=args.cfg, + num_steps=args.steps, + seed=args.seed, + negative_text=args.negative_prompt, ) - if args.quantize: - nn.quantize( - sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear) - ) - nn.quantize(sd.unet, group_size=32, bits=8) - args.cfg = args.cfg or 7.5 - args.steps = args.steps or 50 + for x_t in tqdm(latents, total=args.steps): + mx.eval(x_t) - # Ensure that models are read in memory if needed - if args.preload_models: - sd.ensure_models_are_loaded() + # The following is not necessary but it may help in memory + # constrained systems by reusing the memory kept by the unet and the text + # encoders. + if args.model == "sdxl": + del sd.text_encoder_1 + del sd.text_encoder_2 + else: + del sd.text_encoder + del sd.unet + del sd.sampler + peak_mem_unet = mx.metal.get_peak_memory() / 1024**3 - # Generate the latent vectors using diffusion - latents = sd.generate_latents( - args.prompt, - n_images=args.n_images, - cfg_weight=args.cfg, - num_steps=args.steps, - seed=args.seed, - negative_text=args.negative_prompt, - ) - for x_t in tqdm(latents, total=args.steps): - mx.eval(x_t) + # Decode them into images + decoded = [] + for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): + decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size])) + mx.eval(decoded[-1]) + peak_mem_overall = mx.metal.get_peak_memory() / 1024**3 - # The following is not necessary but it may help in memory - # constrained systems by reusing the memory kept by the unet and the text - # encoders. 
- if args.model == "sdxl": - del sd.text_encoder_1 - del sd.text_encoder_2 - else: - del sd.text_encoder - del sd.unet - del sd.sampler - peak_mem_unet = mx.metal.get_peak_memory() / 1024**3 + # Arrange them on a grid + x = mx.concatenate(decoded, axis=0) + x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)]) + B, H, W, C = x.shape + x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(args.n_rows * H, B // args.n_rows * W, C) + x = (x * 255).astype(mx.uint8) - # Decode them into images - decoded = [] - for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): - decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size])) - mx.eval(decoded[-1]) - peak_mem_overall = mx.metal.get_peak_memory() / 1024**3 + # Save them to disc + im = Image.fromarray(np.array(x)) + im.save(args.output) - # Arrange them on a grid - x = mx.concatenate(decoded, axis=0) - x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)]) - B, H, W, C = x.shape - x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4) - x = x.reshape(args.n_rows * H, B // args.n_rows * W, C) - x = (x * 255).astype(mx.uint8) - - # Save them to disc - im = Image.fromarray(np.array(x)) - im.save(args.output) - - # Report the peak memory used during generation - if args.verbose: - print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB") - print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") + # Report the peak memory used during generation + if args.verbose: + print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB") + print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") + finally: + if args.timeout: + signal.alarm(0)