2024-01-04 07:13:26 +08:00
|
|
|
import argparse
|
|
|
|
import time
|
|
|
|
|
|
|
|
import mlx.core as mx
|
2024-01-13 02:25:56 +08:00
|
|
|
|
|
|
|
from .utils import generate_step, load
|
2024-01-04 07:13:26 +08:00
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
# Fallback values for the command-line options of this script.

# Model location; same value as the default of the ``--model`` option.
DEFAULT_MODEL_PATH: str = "mlx_model"

# Default for ``--prompt``.
DEFAULT_PROMPT: str = "hello"

# Default for ``--max-tokens`` / ``-m``.
DEFAULT_MAX_TOKENS: int = 100

# Default for ``--temp`` (sampling temperature).
DEFAULT_TEMP: float = 0.6

# Default for ``--seed`` (PRNG seed).
DEFAULT_SEED: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
def setup_arg_parser():
    """Set up and return the argument parser.

    Returns:
        argparse.ArgumentParser: A parser exposing ``--model``,
        ``--trust-remote-code``, ``--eos-token``, ``--prompt``,
        ``--max-tokens``/``-m``, ``--temp``, ``--seed`` and
        ``--ignore-chat-template``.
    """
    parser = argparse.ArgumentParser(description="LLM inference script")
    parser.add_argument(
        "--model",
        type=str,
        # Use the shared constant like every other option does, instead of
        # repeating the literal, so the default cannot drift from it.
        default=DEFAULT_MODEL_PATH,
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default=None,
        help="End of sequence token for tokenizer",
    )
    parser.add_argument(
        "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    return parser
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
def main(args):
    """Generate text for ``args.prompt``, stream it to stdout and print stats.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``seed``, ``trust_remote_code``, ``eos_token``, ``model``,
            ``ignore_chat_template``, ``prompt``, ``temp`` and ``max_tokens``.

    Returns:
        None. Returns early, before printing throughput figures, when no
        token was generated.
    """
    mx.random.seed(args.seed)

    # Building tokenizer_config
    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)

    # Wrap the prompt in the tokenizer's chat template unless the user opted
    # out or the tokenizer does not provide one.
    if not args.ignore_chat_template and (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": args.prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = args.prompt

    print("=" * 10)
    print("Prompt:", prompt)
    prompt = tokenizer.encode(prompt)
    prompt = mx.array(prompt)

    # perf_counter() is monotonic, so the measured intervals cannot be skewed
    # by wall-clock adjustments the way time.time() differences can.
    tic = time.perf_counter()
    tokens = []
    skip = 0  # number of characters of the decoded text already printed
    for token, n in zip(
        generate_step(prompt, model, args.temp), range(args.max_tokens)
    ):
        if token == tokenizer.eos_token_id:
            break
        if n == 0:
            # Time to the first token is reported as prompt processing time;
            # the generation timer restarts from here.
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        tokens.append(token.item())
        s = tokenizer.decode(tokens)
        # When the decoded text ends in U+FFFD the last character is likely
        # still incomplete. Printing now and advancing `skip` would emit the
        # placeholder and then drop the real character once it completes, so
        # hold the output back; the print after the loop flushes the rest.
        if not s.endswith("\ufffd"):
            print(s[skip:], end="", flush=True)
            skip = len(s)
    print(tokenizer.decode(tokens)[skip:], flush=True)
    gen_time = time.perf_counter() - tic
    print("=" * 10)
    if not tokens:
        print("No tokens generated for this prompt")
        return
    prompt_tps = prompt.size / prompt_time
    # The first token is attributed to prompt processing, hence the -1.
    gen_tps = (len(tokens) - 1) / gen_time
    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the parser, parse the command line, and hand
    # the resulting namespace to main().
    main(setup_arg_parser().parse_args())
|