mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Add timeout to generate functions
Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.
* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.
* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.
* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
This commit is contained in:
parent
743763bc2e
commit
f7bbe458ae
13
cvae/main.py
13
cvae/main.py
@ -4,6 +4,7 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import signal
|
||||||
|
|
||||||
import dataset
|
import dataset
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -67,7 +68,16 @@ def generate(
|
|||||||
model,
|
model,
|
||||||
out_file,
|
out_file,
|
||||||
num_samples=128,
|
num_samples=128,
|
||||||
|
timeout=None,
|
||||||
):
|
):
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
# Sample from the latent distribution:
|
# Sample from the latent distribution:
|
||||||
z = mx.random.normal([num_samples, model.num_latent_dims])
|
z = mx.random.normal([num_samples, model.num_latent_dims])
|
||||||
|
|
||||||
@ -77,6 +87,9 @@ def generate(
|
|||||||
# Save all images in a single file
|
# Save all images in a single file
|
||||||
grid_image = grid_image_from_batch(images, num_rows=8)
|
grid_image = grid_image_from_batch(images, num_rows=8)
|
||||||
grid_image.save(out_file)
|
grid_image.save(out_file)
|
||||||
|
finally:
|
||||||
|
if timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
@ -4,6 +4,7 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import signal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -16,8 +17,16 @@ from PIL import Image
|
|||||||
from flux import FluxPipeline, Trainer, load_dataset
|
from flux import FluxPipeline, Trainer, load_dataset
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args, timeout=None):
|
||||||
"""Generate images to monitor the progress of the finetuning."""
|
"""Generate images to monitor the progress of the finetuning."""
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
out_dir = Path(args.output_dir)
|
out_dir = Path(args.output_dir)
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
out_file = out_dir / f"{iteration:07d}_progress.png"
|
out_file = out_dir / f"{iteration:07d}_progress.png"
|
||||||
@ -41,6 +50,9 @@ def generate_progress_images(iteration, flux, args):
|
|||||||
# Save them to disc
|
# Save them to disc
|
||||||
im = Image.fromarray(np.array(x))
|
im = Image.fromarray(np.array(x))
|
||||||
im.save(out_file)
|
im.save(out_file)
|
||||||
|
finally:
|
||||||
|
if timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
def save_adapters(iteration, flux, args):
|
def save_adapters(iteration, flux, args):
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import codecs
|
import codecs
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import signal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import requests
|
import requests
|
||||||
@ -49,6 +50,12 @@ def parse_arguments():
|
|||||||
default=None,
|
default=None,
|
||||||
help="End of sequence token for tokenizer",
|
help="End of sequence token for tokenizer",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Timeout in seconds for the generation process.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -119,6 +126,14 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
|
|||||||
def main():
|
def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if args.timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(args.timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
tokenizer_config = {}
|
tokenizer_config = {}
|
||||||
if args.eos_token is not None:
|
if args.eos_token is not None:
|
||||||
tokenizer_config["eos_token"] = args.eos_token
|
tokenizer_config["eos_token"] = args.eos_token
|
||||||
@ -134,6 +149,9 @@ def main():
|
|||||||
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
|
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
|
||||||
)
|
)
|
||||||
print(generated_text)
|
print(generated_text)
|
||||||
|
finally:
|
||||||
|
if args.timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
import signal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import models
|
import models
|
||||||
@ -13,7 +14,16 @@ def generate(
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
temp: float = 0.0,
|
temp: float = 0.0,
|
||||||
|
timeout: int = None,
|
||||||
):
|
):
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
prompt = tokenizer.encode(prompt)
|
prompt = tokenizer.encode(prompt)
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
@ -44,6 +54,9 @@ def generate(
|
|||||||
gen_tps = (len(tokens) - 1) / gen_time
|
gen_tps = (len(tokens) - 1) / gen_time
|
||||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
|
finally:
|
||||||
|
if timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -83,4 +96,4 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
model, tokenizer = models.load(args.gguf, args.repo)
|
model, tokenizer = models.load(args.gguf, args.repo)
|
||||||
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
|
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
import signal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@ -107,6 +108,12 @@ def setup_arg_parser():
|
|||||||
default=None,
|
default=None,
|
||||||
help="A file containing saved KV caches to avoid recomputing them",
|
help="A file containing saved KV caches to avoid recomputing them",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Timeout in seconds for the generation process.",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -146,6 +153,14 @@ def main():
|
|||||||
if args.cache_limit_gb is not None:
|
if args.cache_limit_gb is not None:
|
||||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if args.timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(args.timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
# Load the prompt cache and metadata if a cache file is provided
|
# Load the prompt cache and metadata if a cache file is provided
|
||||||
using_cache = args.prompt_cache_file is not None
|
using_cache = args.prompt_cache_file is not None
|
||||||
if using_cache:
|
if using_cache:
|
||||||
@ -230,6 +245,9 @@ def main():
|
|||||||
)
|
)
|
||||||
if not args.verbose:
|
if not args.verbose:
|
||||||
print(response)
|
print(response)
|
||||||
|
finally:
|
||||||
|
if args.timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
import signal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from decoder import SpeculativeDecoder
|
from decoder import SpeculativeDecoder
|
||||||
@ -21,6 +22,14 @@ def load_model(model_name: str):
|
|||||||
def main(args):
|
def main(args):
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if args.timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(args.timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
spec_decoder = SpeculativeDecoder(
|
spec_decoder = SpeculativeDecoder(
|
||||||
model=load_model(args.model_name),
|
model=load_model(args.model_name),
|
||||||
draft_model=load_model(args.draft_model_name),
|
draft_model=load_model(args.draft_model_name),
|
||||||
@ -42,6 +51,9 @@ def main(args):
|
|||||||
toc = time.time()
|
toc = time.time()
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print(f"Full generation time {toc - tic:.3f}")
|
print(f"Full generation time {toc - tic:.3f}")
|
||||||
|
finally:
|
||||||
|
if args.timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -91,5 +103,11 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use regular decoding instead of speculative decoding.",
|
help="Use regular decoding instead of speculative decoding.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Timeout in seconds for the generation process.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -1,16 +1,28 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import signal
|
||||||
|
|
||||||
from utils import save_audio
|
from utils import save_audio
|
||||||
|
|
||||||
from musicgen import MusicGen
|
from musicgen import MusicGen
|
||||||
|
|
||||||
|
|
||||||
def main(text: str, output_path: str, model_name: str, max_steps: int):
|
def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
model = MusicGen.from_pretrained(model_name)
|
model = MusicGen.from_pretrained(model_name)
|
||||||
audio = model.generate(text, max_steps=max_steps)
|
audio = model.generate(text, max_steps=max_steps)
|
||||||
save_audio(output_path, audio, model.sampling_rate)
|
save_audio(output_path, audio, model.sampling_rate)
|
||||||
|
finally:
|
||||||
|
if timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -19,5 +31,6 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--text", required=False, default="happy rock")
|
parser.add_argument("--text", required=False, default="happy rock")
|
||||||
parser.add_argument("--output-path", required=False, default="0.wav")
|
parser.add_argument("--output-path", required=False, default="0.wav")
|
||||||
parser.add_argument("--max-steps", required=False, default=500, type=int)
|
parser.add_argument("--max-steps", required=False, default=500, type=int)
|
||||||
|
parser.add_argument("--timeout", required=False, default=None, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args.text, args.output_path, args.model, args.max_steps)
|
main(args.text, args.output_path, args.model, args.max_steps, args.timeout)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import signal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -27,18 +28,29 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--preload-models", action="store_true")
|
parser.add_argument("--preload-models", action="store_true")
|
||||||
parser.add_argument("--output", default="out.png")
|
parser.add_argument("--output", default="out.png")
|
||||||
parser.add_argument("--seed", type=int)
|
parser.add_argument("--seed", type=int)
|
||||||
|
parser.add_argument("--timeout", type=int, default=None)
|
||||||
parser.add_argument("--verbose", "-v", action="store_true")
|
parser.add_argument("--verbose", "-v", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Generation timed out")
|
||||||
|
|
||||||
|
if args.timeout:
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(args.timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
# Load the models
|
# Load the models
|
||||||
if args.model == "sdxl":
|
if args.model == "sdxl":
|
||||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||||
if args.quantize:
|
if args.quantize:
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
sd.text_encoder_1,
|
||||||
|
class_predicate=lambda _, m: isinstance(m, nn.Linear),
|
||||||
)
|
)
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
sd.text_encoder_2,
|
||||||
|
class_predicate=lambda _, m: isinstance(m, nn.Linear),
|
||||||
)
|
)
|
||||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||||
args.cfg = args.cfg or 0.0
|
args.cfg = args.cfg or 0.0
|
||||||
@ -49,7 +61,8 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
if args.quantize:
|
if args.quantize:
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
sd.text_encoder,
|
||||||
|
class_predicate=lambda _, m: isinstance(m, nn.Linear),
|
||||||
)
|
)
|
||||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||||
args.cfg = args.cfg or 7.5
|
args.cfg = args.cfg or 7.5
|
||||||
@ -106,3 +119,6 @@ if __name__ == "__main__":
|
|||||||
if args.verbose:
|
if args.verbose:
|
||||||
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
|
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
|
||||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
||||||
|
finally:
|
||||||
|
if args.timeout:
|
||||||
|
signal.alarm(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user