Add timeout to generate functions

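Add optional timeout handling to the generation entry points of several examples. When a timeout (in seconds) is given, a `SIGALRM` alarm is armed that raises `TimeoutError("Generation timed out")` on expiry; the alarm is cancelled in a `finally` block once generation finishes.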

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.

* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function and a matching `--timeout` command-line argument.
  - Implement timeout handling using `signal` module in `main` function.

* **stable_diffusion/txt2image.py**
  - Add `--timeout` command-line argument.
  - Implement timeout handling using `signal` module in the `if __name__ == "__main__":` script body.

* **llava/generate.py**
  - Add `--timeout` command-line argument in `parse_arguments`.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **llms/mlx_lm/generate.py**
  - Add `--timeout` command-line argument in `setup_arg_parser`.
  - Implement timeout handling using `signal` module in `main` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
This commit is contained in:
锦此 2024-10-22 17:06:58 +08:00
parent 743763bc2e
commit f7bbe458ae
8 changed files with 360 additions and 239 deletions

View File

@ -4,6 +4,7 @@ import argparse
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
import signal
import dataset import dataset
import mlx.core as mx import mlx.core as mx
@ -67,7 +68,16 @@ def generate(
model, model,
out_file, out_file,
num_samples=128, num_samples=128,
timeout=None,
): ):
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
# Sample from the latent distribution: # Sample from the latent distribution:
z = mx.random.normal([num_samples, model.num_latent_dims]) z = mx.random.normal([num_samples, model.num_latent_dims])
@ -77,6 +87,9 @@ def generate(
# Save all images in a single file # Save all images in a single file
grid_image = grid_image_from_batch(images, num_rows=8) grid_image = grid_image_from_batch(images, num_rows=8)
grid_image.save(out_file) grid_image.save(out_file)
finally:
if timeout:
signal.alarm(0)
def main(args): def main(args):

View File

@ -4,6 +4,7 @@ import argparse
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
import signal
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -16,8 +17,16 @@ from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset from flux import FluxPipeline, Trainer, load_dataset
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args, timeout=None):
"""Generate images to monitor the progress of the finetuning.""" """Generate images to monitor the progress of the finetuning."""
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
out_dir = Path(args.output_dir) out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_progress.png" out_file = out_dir / f"{iteration:07d}_progress.png"
@ -41,6 +50,9 @@ def generate_progress_images(iteration, flux, args):
# Save them to disc # Save them to disc
im = Image.fromarray(np.array(x)) im = Image.fromarray(np.array(x))
im.save(out_file) im.save(out_file)
finally:
if timeout:
signal.alarm(0)
def save_adapters(iteration, flux, args): def save_adapters(iteration, flux, args):

View File

@ -3,6 +3,7 @@
import argparse import argparse
import codecs import codecs
from pathlib import Path from pathlib import Path
import signal
import mlx.core as mx import mlx.core as mx
import requests import requests
@ -49,6 +50,12 @@ def parse_arguments():
default=None, default=None,
help="End of sequence token for tokenizer", help="End of sequence token for tokenizer",
) )
parser.add_argument(
"--timeout",
type=int,
default=None,
help="Timeout in seconds for the generation process.",
)
return parser.parse_args() return parser.parse_args()
@ -119,6 +126,14 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
def main(): def main():
args = parse_arguments() args = parse_arguments()
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
tokenizer_config = {} tokenizer_config = {}
if args.eos_token is not None: if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token tokenizer_config["eos_token"] = args.eos_token
@ -134,6 +149,9 @@ def main():
input_ids, pixel_values, model, processor, args.max_tokens, args.temp input_ids, pixel_values, model, processor, args.max_tokens, args.temp
) )
print(generated_text) print(generated_text)
finally:
if args.timeout:
signal.alarm(0)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,6 +2,7 @@
import argparse import argparse
import time import time
import signal
import mlx.core as mx import mlx.core as mx
import models import models
@ -13,7 +14,16 @@ def generate(
prompt: str, prompt: str,
max_tokens: int, max_tokens: int,
temp: float = 0.0, temp: float = 0.0,
timeout: int = None,
): ):
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
prompt = tokenizer.encode(prompt) prompt = tokenizer.encode(prompt)
tic = time.time() tic = time.time()
@ -44,6 +54,9 @@ def generate(
gen_tps = (len(tokens) - 1) / gen_time gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec") print(f"Generation: {gen_tps:.3f} tokens-per-sec")
finally:
if timeout:
signal.alarm(0)
if __name__ == "__main__": if __name__ == "__main__":
@ -83,4 +96,4 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
mx.random.seed(args.seed) mx.random.seed(args.seed)
model, tokenizer = models.load(args.gguf, args.repo) model, tokenizer = models.load(args.gguf, args.repo)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)

View File

@ -3,6 +3,7 @@
import argparse import argparse
import json import json
import sys import sys
import signal
import mlx.core as mx import mlx.core as mx
@ -107,6 +108,12 @@ def setup_arg_parser():
default=None, default=None,
help="A file containing saved KV caches to avoid recomputing them", help="A file containing saved KV caches to avoid recomputing them",
) )
parser.add_argument(
"--timeout",
type=int,
default=None,
help="Timeout in seconds for the generation process.",
)
return parser return parser
@ -146,6 +153,14 @@ def main():
if args.cache_limit_gb is not None: if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
# Load the prompt cache and metadata if a cache file is provided # Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None using_cache = args.prompt_cache_file is not None
if using_cache: if using_cache:
@ -230,6 +245,9 @@ def main():
) )
if not args.verbose: if not args.verbose:
print(response) print(response)
finally:
if args.timeout:
signal.alarm(0)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,5 +1,6 @@
import argparse import argparse
import time import time
import signal
import mlx.core as mx import mlx.core as mx
from decoder import SpeculativeDecoder from decoder import SpeculativeDecoder
@ -21,6 +22,14 @@ def load_model(model_name: str):
def main(args): def main(args):
mx.random.seed(args.seed) mx.random.seed(args.seed)
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
spec_decoder = SpeculativeDecoder( spec_decoder = SpeculativeDecoder(
model=load_model(args.model_name), model=load_model(args.model_name),
draft_model=load_model(args.draft_model_name), draft_model=load_model(args.draft_model_name),
@ -42,6 +51,9 @@ def main(args):
toc = time.time() toc = time.time()
print("=" * 10) print("=" * 10)
print(f"Full generation time {toc - tic:.3f}") print(f"Full generation time {toc - tic:.3f}")
finally:
if args.timeout:
signal.alarm(0)
if __name__ == "__main__": if __name__ == "__main__":
@ -91,5 +103,11 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Use regular decoding instead of speculative decoding.", help="Use regular decoding instead of speculative decoding.",
) )
parser.add_argument(
"--timeout",
type=int,
default=None,
help="Timeout in seconds for the generation process.",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1,16 +1,28 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import signal
from utils import save_audio from utils import save_audio
from musicgen import MusicGen from musicgen import MusicGen
def main(text: str, output_path: str, model_name: str, max_steps: int): def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
model = MusicGen.from_pretrained(model_name) model = MusicGen.from_pretrained(model_name)
audio = model.generate(text, max_steps=max_steps) audio = model.generate(text, max_steps=max_steps)
save_audio(output_path, audio, model.sampling_rate) save_audio(output_path, audio, model.sampling_rate)
finally:
if timeout:
signal.alarm(0)
if __name__ == "__main__": if __name__ == "__main__":
@ -19,5 +31,6 @@ if __name__ == "__main__":
parser.add_argument("--text", required=False, default="happy rock") parser.add_argument("--text", required=False, default="happy rock")
parser.add_argument("--output-path", required=False, default="0.wav") parser.add_argument("--output-path", required=False, default="0.wav")
parser.add_argument("--max-steps", required=False, default=500, type=int) parser.add_argument("--max-steps", required=False, default=500, type=int)
parser.add_argument("--timeout", required=False, default=None, type=int)
args = parser.parse_args() args = parser.parse_args()
main(args.text, args.output_path, args.model, args.max_steps) main(args.text, args.output_path, args.model, args.max_steps, args.timeout)

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
import signal
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -27,18 +28,29 @@ if __name__ == "__main__":
parser.add_argument("--preload-models", action="store_true") parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png") parser.add_argument("--output", default="out.png")
parser.add_argument("--seed", type=int) parser.add_argument("--seed", type=int)
parser.add_argument("--timeout", type=int, default=None)
parser.add_argument("--verbose", "-v", action="store_true") parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args() args = parser.parse_args()
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
# Load the models # Load the models
if args.model == "sdxl": if args.model == "sdxl":
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
if args.quantize: if args.quantize:
nn.quantize( nn.quantize(
sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear) sd.text_encoder_1,
class_predicate=lambda _, m: isinstance(m, nn.Linear),
) )
nn.quantize( nn.quantize(
sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear) sd.text_encoder_2,
class_predicate=lambda _, m: isinstance(m, nn.Linear),
) )
nn.quantize(sd.unet, group_size=32, bits=8) nn.quantize(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 0.0 args.cfg = args.cfg or 0.0
@ -49,7 +61,8 @@ if __name__ == "__main__":
) )
if args.quantize: if args.quantize:
nn.quantize( nn.quantize(
sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear) sd.text_encoder,
class_predicate=lambda _, m: isinstance(m, nn.Linear),
) )
nn.quantize(sd.unet, group_size=32, bits=8) nn.quantize(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 7.5 args.cfg = args.cfg or 7.5
@ -106,3 +119,6 @@ if __name__ == "__main__":
if args.verbose: if args.verbose:
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB") print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
finally:
if args.timeout:
signal.alarm(0)