mlx-examples/llms/gguf_llm/generate.py
锦此 f7bbe458ae Add timeout to generate functions
Add timeout handling to the generation entry points (`generate`, `generate_progress_images`, and `main`) across multiple example scripts.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.

* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
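
The pattern repeated in each file above can be sketched as one reusable context manager. This is a minimal illustration, not the code from this commit: `time_limit` and `slow_generate` are hypothetical names, and the sketch assumes a Unix platform and the main thread, since `SIGALRM` is unavailable on Windows and signal handlers can only be installed from the main thread. It uses `signal.setitimer` rather than `signal.alarm` so that fractional seconds are accepted, and it restores the previous handler on exit.

```python
import signal
import time
from contextlib import contextmanager
from typing import Iterator, Optional


@contextmanager
def time_limit(seconds: Optional[float]) -> Iterator[None]:
    """Raise TimeoutError if the body runs longer than `seconds`.

    Hypothetical helper; Unix only, main thread only.
    """
    if not seconds or seconds <= 0:
        # No limit requested: run the body unchanged.
        yield
        return

    def handler(signum, frame):
        raise TimeoutError(f"Timed out after {seconds} seconds")

    # Install the handler and remember whatever was there before.
    previous = signal.signal(signal.SIGALRM, handler)
    # setitimer accepts fractional seconds, unlike signal.alarm.
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Cancel the pending timer and restore the previous handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous)


def slow_generate(steps: int, delay: float) -> int:
    """Stand-in for a generation loop; returns the number of completed steps."""
    done = 0
    for _ in range(steps):
        time.sleep(delay)
        done += 1
    return done
```

A caller would wrap the generation loop as `with time_limit(timeout): ...` and catch `TimeoutError` where a partial result is acceptable. Wrapping the call site this way would keep the signal bookkeeping out of each `generate` or `main` function.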

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
2024-10-22 17:06:58 +08:00

100 lines
2.6 KiB
Python

# Copyright © 2023 Apple Inc.

import argparse
import signal
import time
from typing import Optional

import mlx.core as mx
import models


def generate(
    model: models.Model,
    tokenizer: models.GGUFTokenizer,
    prompt: str,
    max_tokens: int,
    temp: float = 0.0,
    timeout: Optional[int] = None,
):
    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    # SIGALRM is Unix-only and can only be armed from the main thread.
    if timeout:
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)

    try:
        prompt = tokenizer.encode(prompt)

        tic = time.time()
        tokens = []
        skip = 0
        # Use the function's own parameters rather than the global `args`.
        for token, n in zip(
            models.generate(prompt, model, temp),
            range(max_tokens),
        ):
            if token == tokenizer.eos_token_id:
                break

            if n == 0:
                # The first token marks the end of prompt processing.
                prompt_time = time.time() - tic
                tic = time.time()

            tokens.append(token.item())
            s = tokenizer.decode(tokens)
            print(s[skip:], end="", flush=True)
            skip = len(s)
        print(tokenizer.decode(tokens)[skip:], flush=True)
        gen_time = time.time() - tic
        print("=" * 10)
        if len(tokens) == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt.size / prompt_time
        gen_tps = (len(tokens) - 1) / gen_time
        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
    finally:
        # Cancel any pending alarm so it cannot fire after generation ends.
        if timeout:
            signal.alarm(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Inference script")
    parser.add_argument(
        "--gguf",
        type=str,
        help="The GGUF file to load (and optionally download).",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default=None,
        help="The Hugging Face repo if downloading from the Hub.",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    # Without this flag, `args.timeout` below would raise AttributeError.
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help="Abort generation after this many seconds (default: no timeout)",
    )
    args = parser.parse_args()
    mx.random.seed(args.seed)
    model, tokenizer = models.load(args.gguf, args.repo)
    generate(
        model,
        tokenizer,
        args.prompt,
        args.max_tokens,
        args.temp,
        timeout=args.timeout,
    )