2024-01-24 07:41:54 +08:00
|
|
|
# Copyright © 2023 Apple Inc.
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import time
|
2024-10-22 17:06:58 +08:00
|
|
|
import signal
|
2024-01-24 07:41:54 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import models
|
|
|
|
|
|
|
|
|
|
|
|
def generate(
    model: models.Model,
    tokenizer: models.GGUFTokenizer,
    prompt: str,
    max_tokens: int,
    temp: float = 0.0,
    timeout: int = None,
):
    """Stream a completion for ``prompt`` to stdout and report throughput.

    The prompt is encoded with ``tokenizer``, tokens are pulled from
    ``models.generate`` and printed incrementally as they decode, and
    prompt / generation tokens-per-second are printed at the end.

    Args:
        model: Model handed to ``models.generate``.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        prompt: Text to complete.
        max_tokens: Upper bound on the number of tokens pulled from the
            generator.
        temp: Sampling temperature forwarded to ``models.generate``.
        timeout: Optional wall-clock limit in whole seconds. A falsy value
            (``None`` or ``0``) disables the limit. The limit relies on
            ``signal.SIGALRM``, so it is only usable on Unix and from the
            main thread.

    Raises:
        TimeoutError: If ``timeout`` seconds elapse before the function
            finishes.
    """

    def handler(signum, frame):
        raise TimeoutError("Generation timed out")

    previous_handler = None
    if timeout:
        # Remember whatever handler was installed so it can be put back on
        # exit instead of leaking ours into the rest of the process.
        previous_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)

    try:
        prompt = tokenizer.encode(prompt)

        tic = time.time()
        tokens = []
        skip = 0
        # Use the function's own parameters rather than the script-level
        # ``args`` namespace, so the function honours its arguments and also
        # works when imported from another module.
        for token, n in zip(
            models.generate(prompt, model, temp),
            range(max_tokens),
        ):
            if token == tokenizer.eos_token_id:
                break

            if n == 0:
                # The first token marks the end of prompt processing;
                # restart the clock so generation is timed separately.
                prompt_time = time.time() - tic
                tic = time.time()

            tokens.append(token.item())
            # Re-decode the full sequence and print only the new suffix.
            s = tokenizer.decode(tokens)
            print(s[skip:], end="", flush=True)
            skip = len(s)
        print(tokenizer.decode(tokens)[skip:], flush=True)
        gen_time = time.time() - tic
        print("=" * 10)
        if len(tokens) == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt.size / prompt_time
        # The first token is attributed to prompt processing, hence the -1.
        gen_tps = (len(tokens) - 1) / gen_time
        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
    finally:
        if timeout:
            # Cancel any pending alarm, then reinstate the prior handler.
            # ``signal.signal`` returns None when the previous handler was
            # not installed from Python; that value cannot be re-installed.
            signal.alarm(0)
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
|
2024-01-24 07:41:54 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, seed the PRNG, load the GGUF
    # model and tokenizer, then stream a completion for the prompt.
    parser = argparse.ArgumentParser(description="Inference script")
    parser.add_argument(
        "--gguf",
        type=str,
        help="The GGUF file to load (and optionally download).",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default=None,
        help="The Hugging Face repo if downloading from the Hub.",
    )

    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    # ``args.timeout`` is read below, so the option must be registered;
    # without it every run fails with AttributeError.
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help="Maximum number of seconds to allow for generation (no limit by default)",
    )

    args = parser.parse_args()
    mx.random.seed(args.seed)
    model, tokenizer = models.load(args.gguf, args.repo)

    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)
|