Add timeout to generate functions

Add optional timeout handling to the generation entry points across eight example files. Every file applies the same `signal`-based pattern, sketched after the list below.

* **cvae/main.py**
  - Add a `timeout` parameter to `generate` and wrap generation in `signal`-based timeout handling.

* **flux/dreambooth.py**
  - Add a `timeout` parameter to `generate_progress_images` and wrap generation in `signal`-based timeout handling.

* **llava/generate.py**
  - Add a `--timeout` CLI argument and wrap `main` in `signal`-based timeout handling.

* **llms/gguf_llm/generate.py**
  - Add a `timeout` parameter to `generate` and wrap generation in `signal`-based timeout handling.

* **llms/mlx_lm/generate.py**
  - Add a `--timeout` CLI argument and wrap `main` in `signal`-based timeout handling.

* **llms/speculative_decoding/main.py**
  - Add a `--timeout` CLI argument and wrap `main` in `signal`-based timeout handling.

* **musicgen/generate.py**
  - Add a `timeout` parameter to `main`, a `--timeout` CLI argument, and `signal`-based timeout handling.

* **stable_diffusion/txt2image.py**
  - Add a `--timeout` CLI argument and wrap the `__main__` generation script in `signal`-based timeout handling.
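
Each file follows the same pattern: register a `SIGALRM` handler that raises `TimeoutError`, arm the alarm before generation starts, and disarm it in a `finally` block so no stale alarm outlives the call. A minimal standalone sketch of that pattern (`run_generation` is a placeholder callable, not a function from this repo):

```python
import signal


def run_with_timeout(run_generation, timeout=None):
    # Raised in the main thread when the alarm fires.
    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    if timeout:
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)  # deliver SIGALRM after `timeout` seconds

    try:
        return run_generation()
    finally:
        if timeout:
            signal.alarm(0)  # disarm on success, error, or timeout
```

Note that `signal.alarm` takes whole seconds, is Unix-only, and must be used from the main thread, so this is a best-effort guard rather than a portable cancellation mechanism.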

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
Author: 锦此
Date: 2024-10-22 17:06:58 +08:00
Parent: 743763bc2e
Commit: f7bbe458ae
8 changed files with 360 additions and 239 deletions

**cvae/main.py**

```diff
@@ -4,6 +4,7 @@ import argparse
 import time
 from functools import partial
 from pathlib import Path
+import signal
 
 import dataset
 import mlx.core as mx
@@ -67,16 +68,28 @@ def generate(
     model,
     out_file,
     num_samples=128,
+    timeout=None,
 ):
-    # Sample from the latent distribution:
-    z = mx.random.normal([num_samples, model.num_latent_dims])
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    # Decode the latent vectors to images:
-    images = model.decode(z)
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
 
-    # Save all images in a single file
-    grid_image = grid_image_from_batch(images, num_rows=8)
-    grid_image.save(out_file)
+    try:
+        # Sample from the latent distribution:
+        z = mx.random.normal([num_samples, model.num_latent_dims])
+
+        # Decode the latent vectors to images:
+        images = model.decode(z)
+
+        # Save all images in a single file
+        grid_image = grid_image_from_batch(images, num_rows=8)
+        grid_image.save(out_file)
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 def main(args):
```

**flux/dreambooth.py**

```diff
@@ -4,6 +4,7 @@ import argparse
 import time
 from functools import partial
 from pathlib import Path
+import signal
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -16,31 +17,42 @@ from PIL import Image
 from flux import FluxPipeline, Trainer, load_dataset
 
 
-def generate_progress_images(iteration, flux, args):
+def generate_progress_images(iteration, flux, args, timeout=None):
     """Generate images to monitor the progress of the finetuning."""
-    out_dir = Path(args.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"{iteration:07d}_progress.png"
-    print(f"Generating {str(out_file)}", flush=True)
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    # Generate some images and arrange them in a grid
-    n_rows = 2
-    n_images = 4
-    x = flux.generate_images(
-        args.progress_prompt,
-        n_images,
-        args.progress_steps,
-    )
-    x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
-    B, H, W, C = x.shape
-    x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
-    x = x.reshape(n_rows * H, B // n_rows * W, C)
-    x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
-    x = (x * 255).astype(mx.uint8)
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
 
-    # Save them to disc
-    im = Image.fromarray(np.array(x))
-    im.save(out_file)
+    try:
+        out_dir = Path(args.output_dir)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        out_file = out_dir / f"{iteration:07d}_progress.png"
+        print(f"Generating {str(out_file)}", flush=True)
+
+        # Generate some images and arrange them in a grid
+        n_rows = 2
+        n_images = 4
+        x = flux.generate_images(
+            args.progress_prompt,
+            n_images,
+            args.progress_steps,
+        )
+        x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
+        B, H, W, C = x.shape
+        x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+        x = x.reshape(n_rows * H, B // n_rows * W, C)
+        x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
+        x = (x * 255).astype(mx.uint8)
+
+        # Save them to disc
+        im = Image.fromarray(np.array(x))
+        im.save(out_file)
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 def save_adapters(iteration, flux, args):
```

**llava/generate.py**

```diff
@@ -3,6 +3,7 @@
 import argparse
 import codecs
 from pathlib import Path
+import signal
 
 import mlx.core as mx
 import requests
@@ -49,6 +50,12 @@ def parse_arguments():
         default=None,
         help="End of sequence token for tokenizer",
     )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=None,
+        help="Timeout in seconds for the generation process.",
+    )
     return parser.parse_args()
 
 
@@ -119,21 +126,32 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
 def main():
     args = parse_arguments()
 
-    tokenizer_config = {}
-    if args.eos_token is not None:
-        tokenizer_config["eos_token"] = args.eos_token
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    processor, model = load_model(args.model, tokenizer_config)
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
 
-    prompt = codecs.decode(args.prompt, "unicode_escape")
+    try:
+        tokenizer_config = {}
+        if args.eos_token is not None:
+            tokenizer_config["eos_token"] = args.eos_token
 
-    input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
+        processor, model = load_model(args.model, tokenizer_config)
 
-    print(prompt)
-    generated_text = generate_text(
-        input_ids, pixel_values, model, processor, args.max_tokens, args.temp
-    )
-    print(generated_text)
+        prompt = codecs.decode(args.prompt, "unicode_escape")
+
+        input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
+
+        print(prompt)
+        generated_text = generate_text(
+            input_ids, pixel_values, model, processor, args.max_tokens, args.temp
+        )
+        print(generated_text)
+    finally:
+        if args.timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
```

**llms/gguf_llm/generate.py**

```diff
@@ -2,6 +2,7 @@
 
 import argparse
 import time
+import signal
 
 import mlx.core as mx
 import models
@@ -13,37 +14,49 @@ def generate(
     prompt: str,
     max_tokens: int,
     temp: float = 0.0,
+    timeout: int = None,
 ):
-    prompt = tokenizer.encode(prompt)
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    tic = time.time()
-    tokens = []
-    skip = 0
-    for token, n in zip(
-        models.generate(prompt, model, args.temp),
-        range(args.max_tokens),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
 
-        if n == 0:
-            prompt_time = time.time() - tic
-            tic = time.time()
+    try:
+        prompt = tokenizer.encode(prompt)
 
-        tokens.append(token.item())
-        s = tokenizer.decode(tokens)
-        print(s[skip:], end="", flush=True)
-        skip = len(s)
-    print(tokenizer.decode(tokens)[skip:], flush=True)
-    gen_time = time.time() - tic
-    print("=" * 10)
-    if len(tokens) == 0:
-        print("No tokens generated for this prompt")
-        return
-    prompt_tps = prompt.size / prompt_time
-    gen_tps = (len(tokens) - 1) / gen_time
-    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        tic = time.time()
+        tokens = []
+        skip = 0
+        for token, n in zip(
+            models.generate(prompt, model, args.temp),
+            range(args.max_tokens),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
+
+            if n == 0:
+                prompt_time = time.time() - tic
+                tic = time.time()
+
+            tokens.append(token.item())
+            s = tokenizer.decode(tokens)
+            print(s[skip:], end="", flush=True)
+            skip = len(s)
+        print(tokenizer.decode(tokens)[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
@@ -83,4 +96,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     mx.random.seed(args.seed)
     model, tokenizer = models.load(args.gguf, args.repo)
-    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
+    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)
```

**llms/mlx_lm/generate.py**

```diff
@@ -3,6 +3,7 @@
 import argparse
 import json
 import sys
+import signal
 
 import mlx.core as mx
 
@@ -107,6 +108,12 @@ def setup_arg_parser():
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",
     )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=None,
+        help="Timeout in seconds for the generation process.",
+    )
     return parser
 
 
@@ -146,90 +153,101 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Load the prompt cache and metadata if a cache file is provided
-    using_cache = args.prompt_cache_file is not None
-    if using_cache:
-        prompt_cache, metadata = load_prompt_cache(
-            args.prompt_cache_file, return_metadata=True
-        )
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    # Building tokenizer_config
-    tokenizer_config = (
-        {} if not using_cache else json.loads(metadata["tokenizer_config"])
-    )
-    if args.trust_remote_code:
-        tokenizer_config["trust_remote_code"] = True
-    if args.eos_token is not None:
-        tokenizer_config["eos_token"] = args.eos_token
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
 
-    model_path = args.model
-    if using_cache:
-        if model_path is None:
-            model_path = metadata["model"]
-        elif model_path != metadata["model"]:
-            raise ValueError(
-                f"Providing a different model ({model_path}) than that "
-                f"used to create the prompt cache ({metadata['model']}) "
-                "is an error."
-            )
-    model_path = model_path or DEFAULT_MODEL
-
-    model, tokenizer = load(
-        model_path,
-        adapter_path=args.adapter_path,
-        tokenizer_config=tokenizer_config,
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-    elif using_cache:
-        tokenizer.chat_template = metadata["chat_template"]
-
-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
-        messages = [
-            {
-                "role": "user",
-                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
-            }
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        # Treat the prompt as a suffix assuming that the prefix is in the
-        # stored kv cache.
-        if using_cache:
-            test_prompt = tokenizer.apply_chat_template(
-                [{"role": "user", "content": "<query>"}],
-                tokenize=False,
-                add_generation_prompt=True,
-            )
-            prompt = prompt[test_prompt.index("<query>") :]
-    else:
-        prompt = args.prompt
-
-    if args.colorize and not args.verbose:
-        raise ValueError("Cannot use --colorize with --verbose=False")
-    formatter = colorprint_by_t0 if args.colorize else None
-
-    response = generate(
-        model,
-        tokenizer,
-        prompt,
-        args.max_tokens,
-        verbose=args.verbose,
-        formatter=formatter,
-        temp=args.temp,
-        top_p=args.top_p,
-        max_kv_size=args.max_kv_size,
-        prompt_cache=prompt_cache if using_cache else None,
-    )
-    if not args.verbose:
-        print(response)
+    try:
+        # Load the prompt cache and metadata if a cache file is provided
+        using_cache = args.prompt_cache_file is not None
+        if using_cache:
+            prompt_cache, metadata = load_prompt_cache(
+                args.prompt_cache_file, return_metadata=True
+            )
+
+        # Building tokenizer_config
+        tokenizer_config = (
+            {} if not using_cache else json.loads(metadata["tokenizer_config"])
+        )
+        if args.trust_remote_code:
+            tokenizer_config["trust_remote_code"] = True
+        if args.eos_token is not None:
+            tokenizer_config["eos_token"] = args.eos_token
+
+        model_path = args.model
+        if using_cache:
+            if model_path is None:
+                model_path = metadata["model"]
+            elif model_path != metadata["model"]:
+                raise ValueError(
+                    f"Providing a different model ({model_path}) than that "
+                    f"used to create the prompt cache ({metadata['model']}) "
+                    "is an error."
+                )
+        model_path = model_path or DEFAULT_MODEL
+
+        model, tokenizer = load(
+            model_path,
+            adapter_path=args.adapter_path,
+            tokenizer_config=tokenizer_config,
+        )
+
+        if args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+        elif using_cache:
+            tokenizer.chat_template = metadata["chat_template"]
+
+        if not args.ignore_chat_template and (
+            hasattr(tokenizer, "apply_chat_template")
+            and tokenizer.chat_template is not None
+        ):
+            messages = [
+                {
+                    "role": "user",
+                    "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+                }
+            ]
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+
+            # Treat the prompt as a suffix assuming that the prefix is in the
+            # stored kv cache.
+            if using_cache:
+                test_prompt = tokenizer.apply_chat_template(
+                    [{"role": "user", "content": "<query>"}],
+                    tokenize=False,
+                    add_generation_prompt=True,
+                )
+                prompt = prompt[test_prompt.index("<query>") :]
+        else:
+            prompt = args.prompt
+
+        if args.colorize and not args.verbose:
+            raise ValueError("Cannot use --colorize with --verbose=False")
+        formatter = colorprint_by_t0 if args.colorize else None
+
+        response = generate(
+            model,
+            tokenizer,
+            prompt,
+            args.max_tokens,
+            verbose=args.verbose,
+            formatter=formatter,
+            temp=args.temp,
+            top_p=args.top_p,
+            max_kv_size=args.max_kv_size,
+            prompt_cache=prompt_cache if using_cache else None,
+        )
+        if not args.verbose:
+            print(response)
+    finally:
+        if args.timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
```

**llms/speculative_decoding/main.py**

```diff
@@ -1,5 +1,6 @@
 import argparse
 import time
+import signal
 
 import mlx.core as mx
 from decoder import SpeculativeDecoder
@@ -21,27 +22,38 @@ def load_model(model_name: str):
 
 def main(args):
     mx.random.seed(args.seed)
 
-    spec_decoder = SpeculativeDecoder(
-        model=load_model(args.model_name),
-        draft_model=load_model(args.draft_model_name),
-        tokenizer=args.model_name,
-        delta=args.delta,
-        num_draft=args.num_draft,
-    )
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    tic = time.time()
-    print(args.prompt)
-    if args.regular_decode:
-        spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
-    else:
-        stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
-        print("=" * 10)
-        print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
-        print(f"Decoding steps {stats['n_steps']}.")
-    toc = time.time()
-    print("=" * 10)
-    print(f"Full generation time {toc - tic:.3f}")
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
+
+    try:
+        spec_decoder = SpeculativeDecoder(
+            model=load_model(args.model_name),
+            draft_model=load_model(args.draft_model_name),
+            tokenizer=args.model_name,
+            delta=args.delta,
+            num_draft=args.num_draft,
+        )
+
+        tic = time.time()
+        print(args.prompt)
+        if args.regular_decode:
+            spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
+        else:
+            stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
+            print("=" * 10)
+            print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
+            print(f"Decoding steps {stats['n_steps']}.")
+        toc = time.time()
+        print("=" * 10)
+        print(f"Full generation time {toc - tic:.3f}")
+    finally:
+        if args.timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
@@ -91,5 +103,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Use regular decoding instead of speculative decoding.",
     )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=None,
+        help="Timeout in seconds for the generation process.",
+    )
     args = parser.parse_args()
     main(args)
```

**musicgen/generate.py**

```diff
@@ -1,16 +1,28 @@
 # Copyright © 2024 Apple Inc.
 
 import argparse
+import signal
 
 from utils import save_audio
 
 from musicgen import MusicGen
 
 
-def main(text: str, output_path: str, model_name: str, max_steps: int):
-    model = MusicGen.from_pretrained(model_name)
-    audio = model.generate(text, max_steps=max_steps)
-    save_audio(output_path, audio, model.sampling_rate)
+def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
+
+    if timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
+
+    try:
+        model = MusicGen.from_pretrained(model_name)
+        audio = model.generate(text, max_steps=max_steps)
+        save_audio(output_path, audio, model.sampling_rate)
+    finally:
+        if timeout:
+            signal.alarm(0)
 
 
 if __name__ == "__main__":
@@ -19,5 +31,6 @@ if __name__ == "__main__":
     parser.add_argument("--text", required=False, default="happy rock")
     parser.add_argument("--output-path", required=False, default="0.wav")
     parser.add_argument("--max-steps", required=False, default=500, type=int)
+    parser.add_argument("--timeout", required=False, default=None, type=int)
     args = parser.parse_args()
-    main(args.text, args.output_path, args.model, args.max_steps)
+    main(args.text, args.output_path, args.model, args.max_steps, args.timeout)
```

**stable_diffusion/txt2image.py**

```diff
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import signal
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -27,82 +28,97 @@ if __name__ == "__main__":
     parser.add_argument("--preload-models", action="store_true")
     parser.add_argument("--output", default="out.png")
     parser.add_argument("--seed", type=int)
+    parser.add_argument("--timeout", type=int, default=None)
     parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
-    # Load the models
-    if args.model == "sdxl":
-        sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
-        if args.quantize:
-            nn.quantize(
-                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
-            )
-            nn.quantize(
-                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
-            )
-            nn.quantize(sd.unet, group_size=32, bits=8)
-        args.cfg = args.cfg or 0.0
-        args.steps = args.steps or 2
-    else:
-        sd = StableDiffusion(
-            "stabilityai/stable-diffusion-2-1-base", float16=args.float16
-        )
-        if args.quantize:
-            nn.quantize(
-                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
-            )
-            nn.quantize(sd.unet, group_size=32, bits=8)
-        args.cfg = args.cfg or 7.5
-        args.steps = args.steps or 50
+    def handler(signum, frame):
+        raise TimeoutError("Generation timed out")
 
-    # Ensure that models are read in memory if needed
-    if args.preload_models:
-        sd.ensure_models_are_loaded()
+    if args.timeout:
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(args.timeout)
 
-    # Generate the latent vectors using diffusion
-    latents = sd.generate_latents(
-        args.prompt,
-        n_images=args.n_images,
-        cfg_weight=args.cfg,
-        num_steps=args.steps,
-        seed=args.seed,
-        negative_text=args.negative_prompt,
-    )
-    for x_t in tqdm(latents, total=args.steps):
-        mx.eval(x_t)
+    try:
+        # Load the models
+        if args.model == "sdxl":
+            sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
+            if args.quantize:
+                nn.quantize(
+                    sd.text_encoder_1,
+                    class_predicate=lambda _, m: isinstance(m, nn.Linear),
+                )
+                nn.quantize(
+                    sd.text_encoder_2,
+                    class_predicate=lambda _, m: isinstance(m, nn.Linear),
+                )
+                nn.quantize(sd.unet, group_size=32, bits=8)
+            args.cfg = args.cfg or 0.0
+            args.steps = args.steps or 2
+        else:
+            sd = StableDiffusion(
+                "stabilityai/stable-diffusion-2-1-base", float16=args.float16
+            )
+            if args.quantize:
+                nn.quantize(
+                    sd.text_encoder,
+                    class_predicate=lambda _, m: isinstance(m, nn.Linear),
+                )
+                nn.quantize(sd.unet, group_size=32, bits=8)
+            args.cfg = args.cfg or 7.5
+            args.steps = args.steps or 50
 
-    # The following is not necessary but it may help in memory
-    # constrained systems by reusing the memory kept by the unet and the text
-    # encoders.
-    if args.model == "sdxl":
-        del sd.text_encoder_1
-        del sd.text_encoder_2
-    else:
-        del sd.text_encoder
-    del sd.unet
-    del sd.sampler
-    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
+        # Ensure that models are read in memory if needed
+        if args.preload_models:
+            sd.ensure_models_are_loaded()
 
-    # Decode them into images
-    decoded = []
-    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
-        decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
-        mx.eval(decoded[-1])
-    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
+        # Generate the latent vectors using diffusion
+        latents = sd.generate_latents(
+            args.prompt,
+            n_images=args.n_images,
+            cfg_weight=args.cfg,
+            num_steps=args.steps,
+            seed=args.seed,
+            negative_text=args.negative_prompt,
+        )
+        for x_t in tqdm(latents, total=args.steps):
+            mx.eval(x_t)
 
-    # Arrange them on a grid
-    x = mx.concatenate(decoded, axis=0)
-    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
-    B, H, W, C = x.shape
-    x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
-    x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
-    x = (x * 255).astype(mx.uint8)
+        # The following is not necessary but it may help in memory
+        # constrained systems by reusing the memory kept by the unet and the text
+        # encoders.
+        if args.model == "sdxl":
+            del sd.text_encoder_1
+            del sd.text_encoder_2
+        else:
+            del sd.text_encoder
+        del sd.unet
+        del sd.sampler
+        peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
 
-    # Save them to disc
-    im = Image.fromarray(np.array(x))
-    im.save(args.output)
+        # Decode them into images
+        decoded = []
+        for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+            decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
+            mx.eval(decoded[-1])
+        peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
 
-    # Report the peak memory used during generation
-    if args.verbose:
-        print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
-        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
+        # Arrange them on a grid
+        x = mx.concatenate(decoded, axis=0)
+        x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
+        B, H, W, C = x.shape
+        x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+        x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
+        x = (x * 255).astype(mx.uint8)
+
+        # Save them to disc
+        im = Image.fromarray(np.array(x))
+        im.save(args.output)
+
+        # Report the peak memory used during generation
+        if args.verbose:
+            print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
+            print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
+    finally:
+        if args.timeout:
+            signal.alarm(0)
```
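
A timed-out run surfaces as an ordinary `TimeoutError` at the call site. A hypothetical caller of the updated `musicgen` entry point (the model name below is illustrative, not taken from this diff):

```python
# Hypothetical usage of the updated main() from musicgen/generate.py.
try:
    main("happy rock", "0.wav", "facebook/musicgen-medium", 500, timeout=60)
except TimeoutError:
    print("Generation did not finish within 60 seconds; no audio was written.")
```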