mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00

Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.
* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.
* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.
* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import argparse
|
|
import signal
|
|
|
|
from utils import save_audio
|
|
|
|
from musicgen import MusicGen
|
|
|
|
|
|
def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
    """Generate audio for ``text`` with MusicGen and save it to ``output_path``.

    The model is loaded with ``MusicGen.from_pretrained(model_name)``, its
    ``generate`` method is called with ``text`` and ``max_steps``, and the
    result is handed to ``save_audio`` together with ``model.sampling_rate``.

    Args:
        text: Prompt forwarded to ``model.generate``.
        output_path: Destination forwarded to ``save_audio``.
        model_name: Identifier forwarded to ``MusicGen.from_pretrained``.
        max_steps: Forwarded to ``model.generate`` as ``max_steps``.
        timeout: Optional limit, in whole seconds, covering model loading,
            generation and saving. ``None``, ``0`` or a negative value
            disables the limit.

    Raises:
        TimeoutError: If the limit elapses before the audio has been saved.

    Note:
        The limit is implemented with ``signal.alarm`` / ``SIGALRM``; per the
        Python ``signal`` documentation this is only available on Unix and
        handlers can only be installed from the main thread.
    """

    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    # Only a strictly positive timeout arms the alarm; a negative value would
    # otherwise be handed to signal.alarm() unchecked.
    use_alarm = timeout is not None and timeout > 0

    previous_handler = None
    if use_alarm:
        # signal.signal() returns the handler it replaces; keep it so the
        # caller's SIGALRM disposition can be put back afterwards.
        previous_handler = signal.signal(signal.SIGALRM, handler)

    try:
        if use_alarm:
            signal.alarm(timeout)
        model = MusicGen.from_pretrained(model_name)
        audio = model.generate(text, max_steps=max_steps)
        save_audio(output_path, audio, model.sampling_rate)
    finally:
        if use_alarm:
            # Cancel any pending alarm first so it cannot fire while the
            # handler is being swapped back.
            signal.alarm(0)
            # getsignal()/signal() report None for handlers that were not
            # installed from Python; those cannot be re-installed from here.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: declare the command-line flags from a table, parse
    # them, and forward the parsed values to main().
    cli = argparse.ArgumentParser()
    flag_table = (
        ("--model", {"default": "facebook/musicgen-medium"}),
        ("--text", {"default": "happy rock"}),
        ("--output-path", {"default": "0.wav"}),
        ("--max-steps", {"default": 500, "type": int}),
        ("--timeout", {"default": None, "type": int}),
    )
    for flag, options in flag_table:
        cli.add_argument(flag, required=False, **options)

    opts = cli.parse_args()
    main(
        text=opts.text,
        output_path=opts.output_path,
        model_name=opts.model,
        max_steps=opts.max_steps,
        timeout=opts.timeout,
    )
|