mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00

Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add a `timeout` parameter to the `generate` function.
  - Implement timeout handling in `generate` using the `signal` module.
* **flux/dreambooth.py**
  - Add a `timeout` parameter to the `generate_progress_images` function.
  - Implement timeout handling in `generate_progress_images` using the `signal` module.
* **musicgen/generate.py**
  - Add a `timeout` parameter to the `main` function.
  - Implement timeout handling in `main` using the `signal` module.
* **stable_diffusion/txt2image.py**
  - Add a `timeout` parameter to the `main` function.
  - Implement timeout handling in `main` using the `signal` module.
* **llava/generate.py**
  - Add a `timeout` parameter to the `main` function.
  - Implement timeout handling in `main` using the `signal` module.
* **llms/gguf_llm/generate.py**
  - Add a `timeout` parameter to the `generate` function.
  - Implement timeout handling in `generate` using the `signal` module.
* **llms/mlx_lm/generate.py**
  - Add a `timeout` parameter to the `main` function.
  - Implement timeout handling in `main` using the `signal` module.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
114 lines · 3.1 KiB · Python
import argparse
|
|
import time
|
|
import signal
|
|
|
|
import mlx.core as mx
|
|
from decoder import SpeculativeDecoder
|
|
from mlx.utils import tree_unflatten
|
|
from model import Model
|
|
from transformers import T5Config
|
|
|
|
|
|
def load_model(model_name: str):
    """Construct a ``Model`` for *model_name* and load its converted weights.

    The architecture is described by ``T5Config.from_pretrained(model_name)``;
    the parameters are read from the file ``f"{model_name}.npz"`` in the
    current working directory.

    Args:
        model_name: Identifier passed to ``T5Config.from_pretrained`` and used
            as the stem of the ``.npz`` weights file.

    Returns:
        The ``Model`` instance with its parameters updated from the file.
    """
    model = Model(T5Config.from_pretrained(model_name))

    # mx.load yields a flat name -> array mapping; tree_unflatten rebuilds the
    # nested structure that Model.update expects.
    flat_weights = mx.load(f"{model_name}.npz")
    model.update(tree_unflatten(list(flat_weights.items())))

    # Evaluate the parameters now so the loading cost is paid here rather
    # than deferred to the first forward pass.
    mx.eval(model.parameters())
    return model
|
|
|
|
|
|
def main(args):
    """Load the target and draft models, decode ``args.prompt`` and report timing.

    When ``args.regular_decode`` is set, plain ``SpeculativeDecoder.generate``
    is used; otherwise ``speculative_decode`` runs and its acceptance
    statistics are printed. The elapsed wall-clock time of the decoding step
    (model loading excluded) is printed at the end.

    Args:
        args: Parsed command-line namespace. Reads ``timeout``, ``seed``,
            ``model_name``, ``draft_model_name``, ``delta``, ``num_draft``,
            ``prompt``, ``regular_decode`` and ``max_tokens``.
            ``timeout`` is a whole number of seconds; ``None`` or ``0``
            disables the timeout.

    Raises:
        ValueError: If ``args.timeout`` is negative.
        TimeoutError: If a timeout is set and decoding does not finish
            within that many seconds.
    """
    timeout = args.timeout
    # Validate before any side effects. signal.alarm() takes an unsigned int,
    # so a negative value would silently wrap to an enormous delay instead of
    # being rejected.
    if timeout is not None and timeout < 0:
        raise ValueError(f"timeout must be non-negative, got {timeout}")

    mx.random.seed(args.seed)

    # Model loading happens before the alarm is armed: the timeout only
    # bounds the generation process, as documented for --timeout.
    spec_decoder = SpeculativeDecoder(
        model=load_model(args.model_name),
        draft_model=load_model(args.draft_model_name),
        tokenizer=args.model_name,
        delta=args.delta,
        num_draft=args.num_draft,
    )

    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    # NOTE: Python runs signal handlers only between bytecode instructions on
    # the main thread, so a single long-running native call is not
    # interrupted until it returns.
    previous_handler = None
    if timeout:
        previous_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)

    try:
        tic = time.time()
        print(args.prompt)
        if args.regular_decode:
            spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
        else:
            stats = spec_decoder.speculative_decode(
                args.prompt, max_tokens=args.max_tokens
            )
            print("=" * 10)
            print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
            print(f"Decoding steps {stats['n_steps']}.")

        toc = time.time()
        print("=" * 10)
        print(f"Full generation time {toc - tic:.3f}")
    finally:
        if timeout:
            # Cancel any pending alarm first so it cannot fire after the
            # original handler is back in place, then restore that handler.
            # signal.signal() returns None when the previous handler was not
            # installed from Python; fall back to the default disposition.
            signal.alarm(0)
            signal.signal(
                signal.SIGALRM,
                signal.SIG_DFL if previous_handler is None else previous_handler,
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the decoding options and run main().
    parser = argparse.ArgumentParser(
        description="Speculative (or regular) decoding with T5 models in MLX"
    )
    parser.add_argument(
        "--num-draft",
        type=int,
        default=5,
        help="Number of draft tokens to use per decoding step.",
    )
    parser.add_argument(
        "--model-name",
        help="Name of the model.",
        default="t5-small",
    )
    parser.add_argument(
        "--draft-model-name",
        help="Name of the draft model.",
        default="t5-small",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="PRNG seed.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--prompt",
        default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.",
        help="The prompt processed by the model.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=0.1,
        help="Lenience for accepting the proposal tokens.",
    )
    parser.add_argument(
        "--regular-decode",
        action="store_true",
        help="Use regular decoding instead of speculative decoding.",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help=(
            "Timeout in seconds for the generation process "
            "(unset or 0 disables the timeout)."
        ),
    )
    args = parser.parse_args()
    # Reject negative timeouts with a proper usage error: signal.alarm() would
    # otherwise wrap them to an effectively infinite delay.
    if args.timeout is not None and args.timeout < 0:
        parser.error(f"--timeout must be non-negative, got {args.timeout}")
    main(args)
|