mlx-examples/llms/speculative_decoding/main.py
锦此 f7bbe458ae Add timeout to generate functions
Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.

* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.

* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
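
The `signal`-based timeout that each bullet above describes can be sketched once as a reusable context manager. This is a minimal illustration, not code from this commit: the names `time_limit` and `slow_generate` are hypothetical, and the approach assumes a Unix platform (where `signal.SIGALRM` exists) and that the guarded code runs in the main thread.

```python
import signal
import time
from contextlib import contextmanager


@contextmanager
def time_limit(seconds):
    """Raise TimeoutError if the body runs longer than `seconds`.

    Hypothetical helper for illustration; a falsy `seconds` disables the limit.
    """
    if not seconds:
        yield
        return

    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    # Install the handler and remember the previous one so it can be restored.
    previous = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)  # whole seconds only
    try:
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)


def slow_generate(steps, delay):
    """Stand-in for a generation loop (hypothetical)."""
    out = []
    for i in range(steps):
        time.sleep(delay)
        out.append(i)
    return out
```

Wrapping a call as `with time_limit(args.timeout): ...` keeps the alarm setup and cleanup in one place, and restoring the previous handler avoids leaking the timeout handler into code that runs afterwards.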

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
2024-10-22 17:06:58 +08:00

114 lines
3.1 KiB
Python

import argparse
import signal
import time

import mlx.core as mx
from decoder import SpeculativeDecoder
from mlx.utils import tree_unflatten
from model import Model
from transformers import T5Config


def load_model(model_name: str):
    config = T5Config.from_pretrained(model_name)
    model = Model(config)
    weights = mx.load(f"{model_name}.npz")
    weights = tree_unflatten(list(weights.items()))
    model.update(weights)
    mx.eval(model.parameters())
    return model


def main(args):
    mx.random.seed(args.seed)

    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    # SIGALRM is only available on Unix and must be set from the main thread.
    if args.timeout:
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(args.timeout)

    try:
        spec_decoder = SpeculativeDecoder(
            model=load_model(args.model_name),
            draft_model=load_model(args.draft_model_name),
            tokenizer=args.model_name,
            delta=args.delta,
            num_draft=args.num_draft,
        )

        tic = time.time()
        print(args.prompt)
        if args.regular_decode:
            spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
        else:
            stats = spec_decoder.speculative_decode(
                args.prompt, max_tokens=args.max_tokens
            )
            print("=" * 10)
            print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
            print(f"Decoding steps {stats['n_steps']}.")

        toc = time.time()
        print("=" * 10)
        print(f"Full generation time {toc - tic:.3f}")
    finally:
        # Cancel any pending alarm, whether generation finished or failed.
        if args.timeout:
            signal.alarm(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate text with T5 using speculative decoding in MLX"
    )
    parser.add_argument(
        "--num-draft",
        type=int,
        default=5,
        help="Number of draft tokens to use per decoding step.",
    )
    parser.add_argument(
        "--model-name",
        help="Name of the model.",
        default="t5-small",
    )
    parser.add_argument(
        "--draft-model-name",
        help="Name of the draft model.",
        default="t5-small",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="PRNG seed.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--prompt",
        default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.",
        help="The prompt processed by the model.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=0.1,
        help="Lenience for accepting the proposal tokens.",
    )
    parser.add_argument(
        "--regular-decode",
        action="store_true",
        help="Use regular decoding instead of speculative decoding.",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help="Timeout in seconds for the generation process.",
    )
    args = parser.parse_args()
    main(args)