Add timeout to generate functions

Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.

* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
This commit is contained in:
锦此
2024-10-22 17:06:58 +08:00
parent 743763bc2e
commit f7bbe458ae
8 changed files with 360 additions and 239 deletions

View File

@@ -4,6 +4,7 @@ import argparse
import time
from functools import partial
from pathlib import Path
import signal
import dataset
import mlx.core as mx
@@ -67,16 +68,28 @@ def generate(
model,
out_file,
num_samples=128,
timeout=None,
):
# Sample from the latent distribution:
z = mx.random.normal([num_samples, model.num_latent_dims])
def handler(signum, frame):
raise TimeoutError("Generation timed out")
# Decode the latent vectors to images:
images = model.decode(z)
if timeout:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
# Save all images in a single file
grid_image = grid_image_from_batch(images, num_rows=8)
grid_image.save(out_file)
try:
# Sample from the latent distribution:
z = mx.random.normal([num_samples, model.num_latent_dims])
# Decode the latent vectors to images:
images = model.decode(z)
# Save all images in a single file
grid_image = grid_image_from_batch(images, num_rows=8)
grid_image.save(out_file)
finally:
if timeout:
signal.alarm(0)
def main(args):