mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-08 10:14:36 +08:00
Add timeout to generate functions
Add timeout handling to various `generate` functions across multiple files. * **cvae/main.py** - Add `timeout` parameter to `generate` function. - Implement timeout handling using `signal` module in `generate` function. * **flux/dreambooth.py** - Add `timeout` parameter to `generate_progress_images` function. - Implement timeout handling using `signal` module in `generate_progress_images` function. * **musicgen/generate.py** - Add `timeout` parameter to `main` function. - Implement timeout handling using `signal` module in `main` function. * **stable_diffusion/txt2image.py** - Add `timeout` parameter to `main` function. - Implement timeout handling using `signal` module in `main` function. * **llava/generate.py** - Add `timeout` parameter to `main` function. - Implement timeout handling using `signal` module in `main` function. * **llms/gguf_llm/generate.py** - Add `timeout` parameter to `generate` function. - Implement timeout handling using `signal` module in `generate` function. * **llms/mlx_lm/generate.py** - Add `timeout` parameter to `main` function. - Implement timeout handling using `signal` module in `main` function. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
This commit is contained in:
27
cvae/main.py
27
cvae/main.py
@@ -4,6 +4,7 @@ import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
import signal
|
||||
|
||||
import dataset
|
||||
import mlx.core as mx
|
||||
@@ -67,16 +68,28 @@ def generate(
|
||||
model,
|
||||
out_file,
|
||||
num_samples=128,
|
||||
timeout=None,
|
||||
):
|
||||
# Sample from the latent distribution:
|
||||
z = mx.random.normal([num_samples, model.num_latent_dims])
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
# Decode the latent vectors to images:
|
||||
images = model.decode(z)
|
||||
if timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
# Save all images in a single file
|
||||
grid_image = grid_image_from_batch(images, num_rows=8)
|
||||
grid_image.save(out_file)
|
||||
try:
|
||||
# Sample from the latent distribution:
|
||||
z = mx.random.normal([num_samples, model.num_latent_dims])
|
||||
|
||||
# Decode the latent vectors to images:
|
||||
images = model.decode(z)
|
||||
|
||||
# Save all images in a single file
|
||||
grid_image = grid_image_from_batch(images, num_rows=8)
|
||||
grid_image.save(out_file)
|
||||
finally:
|
||||
if timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
def main(args):
|
||||
|
Reference in New Issue
Block a user