Add timeout to generate functions

Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.

* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
This commit is contained in:
锦此 2024-10-22 17:06:58 +08:00
parent 743763bc2e
commit f7bbe458ae
8 changed files with 360 additions and 239 deletions

View File

@ -4,6 +4,7 @@ import argparse
import time
from functools import partial
from pathlib import Path
import signal
import dataset
import mlx.core as mx
@ -67,7 +68,16 @@ def generate(
model,
out_file,
num_samples=128,
timeout=None,
):
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
# Sample from the latent distribution:
z = mx.random.normal([num_samples, model.num_latent_dims])
@ -77,6 +87,9 @@ def generate(
# Save all images in a single file
grid_image = grid_image_from_batch(images, num_rows=8)
grid_image.save(out_file)
finally:
if timeout:
signal.alarm(0)
def main(args):

View File

@ -4,6 +4,7 @@ import argparse
import time
from functools import partial
from pathlib import Path
import signal
import mlx.core as mx
import mlx.nn as nn
@ -16,8 +17,16 @@ from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset
def generate_progress_images(iteration, flux, args):
def generate_progress_images(iteration, flux, args, timeout=None):
"""Generate images to monitor the progress of the finetuning."""
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_progress.png"
@ -41,6 +50,9 @@ def generate_progress_images(iteration, flux, args):
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(out_file)
finally:
if timeout:
signal.alarm(0)
def save_adapters(iteration, flux, args):

View File

@ -3,6 +3,7 @@
import argparse
import codecs
from pathlib import Path
import signal
import mlx.core as mx
import requests
@ -49,6 +50,12 @@ def parse_arguments():
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--timeout",
type=int,
default=None,
help="Timeout in seconds for the generation process.",
)
return parser.parse_args()
@ -119,6 +126,14 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
def main():
args = parse_arguments()
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
tokenizer_config = {}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
@ -134,6 +149,9 @@ def main():
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
)
print(generated_text)
finally:
if args.timeout:
signal.alarm(0)
if __name__ == "__main__":

View File

@ -2,6 +2,7 @@
import argparse
import time
import signal
import mlx.core as mx
import models
@ -13,7 +14,16 @@ def generate(
prompt: str,
max_tokens: int,
temp: float = 0.0,
timeout: int = None,
):
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
prompt = tokenizer.encode(prompt)
tic = time.time()
@ -44,6 +54,9 @@ def generate(
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
finally:
if timeout:
signal.alarm(0)
if __name__ == "__main__":
@ -83,4 +96,4 @@ if __name__ == "__main__":
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = models.load(args.gguf, args.repo)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)

View File

@ -3,6 +3,7 @@
import argparse
import json
import sys
import signal
import mlx.core as mx
@ -107,6 +108,12 @@ def setup_arg_parser():
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
parser.add_argument(
"--timeout",
type=int,
default=None,
help="Timeout in seconds for the generation process.",
)
return parser
@ -146,6 +153,14 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
if using_cache:
@ -230,6 +245,9 @@ def main():
)
if not args.verbose:
print(response)
finally:
if args.timeout:
signal.alarm(0)
if __name__ == "__main__":

View File

@ -1,5 +1,6 @@
import argparse
import time
import signal
import mlx.core as mx
from decoder import SpeculativeDecoder
@ -21,6 +22,14 @@ def load_model(model_name: str):
def main(args):
mx.random.seed(args.seed)
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
spec_decoder = SpeculativeDecoder(
model=load_model(args.model_name),
draft_model=load_model(args.draft_model_name),
@ -42,6 +51,9 @@ def main(args):
toc = time.time()
print("=" * 10)
print(f"Full generation time {toc - tic:.3f}")
finally:
if args.timeout:
signal.alarm(0)
if __name__ == "__main__":
@ -91,5 +103,11 @@ if __name__ == "__main__":
action="store_true",
help="Use regular decoding instead of speculative decoding.",
)
parser.add_argument(
"--timeout",
type=int,
default=None,
help="Timeout in seconds for the generation process.",
)
args = parser.parse_args()
main(args)

View File

@ -1,16 +1,28 @@
# Copyright © 2024 Apple Inc.
import argparse
import signal
from utils import save_audio
from musicgen import MusicGen
def main(text: str, output_path: str, model_name: str, max_steps: int):
def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
model = MusicGen.from_pretrained(model_name)
audio = model.generate(text, max_steps=max_steps)
save_audio(output_path, audio, model.sampling_rate)
finally:
if timeout:
signal.alarm(0)
if __name__ == "__main__":
@ -19,5 +31,6 @@ if __name__ == "__main__":
parser.add_argument("--text", required=False, default="happy rock")
parser.add_argument("--output-path", required=False, default="0.wav")
parser.add_argument("--max-steps", required=False, default=500, type=int)
parser.add_argument("--timeout", required=False, default=None, type=int)
args = parser.parse_args()
main(args.text, args.output_path, args.model, args.max_steps)
main(args.text, args.output_path, args.model, args.max_steps, args.timeout)

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import argparse
import signal
import mlx.core as mx
import mlx.nn as nn
@ -27,18 +28,29 @@ if __name__ == "__main__":
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--seed", type=int)
parser.add_argument("--timeout", type=int, default=None)
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
def handler(signum, frame):
raise TimeoutError("Generation timed out")
if args.timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.timeout)
try:
# Load the models
if args.model == "sdxl":
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
if args.quantize:
nn.quantize(
sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
sd.text_encoder_1,
class_predicate=lambda _, m: isinstance(m, nn.Linear),
)
nn.quantize(
sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
sd.text_encoder_2,
class_predicate=lambda _, m: isinstance(m, nn.Linear),
)
nn.quantize(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 0.0
@ -49,7 +61,8 @@ if __name__ == "__main__":
)
if args.quantize:
nn.quantize(
sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
sd.text_encoder,
class_predicate=lambda _, m: isinstance(m, nn.Linear),
)
nn.quantize(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 7.5
@ -106,3 +119,6 @@ if __name__ == "__main__":
if args.verbose:
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
finally:
if args.timeout:
signal.alarm(0)