2024-02-28 00:47:56 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-01-04 07:13:26 +08:00
|
|
|
import argparse
|
2024-08-29 13:11:45 +08:00
|
|
|
import json
|
2024-09-04 04:29:10 +08:00
|
|
|
import sys
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
2024-01-13 02:25:56 +08:00
|
|
|
|
2024-10-08 11:45:51 +08:00
|
|
|
from .models.cache import load_prompt_cache
|
2024-01-24 04:44:23 +08:00
|
|
|
from .utils import generate, load
|
2024-01-04 07:13:26 +08:00
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
# Default values for the command-line options of this script.

# Prompt and generation length.
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100

# Sampling parameters.
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0

# Model used when neither --model nor a prompt cache's metadata names one.
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
2024-01-12 04:29:12 +08:00
|
|
|
|
|
|
|
|
2024-09-04 04:29:10 +08:00
|
|
|
# Spellings (after lower-casing and stripping) interpreted as "off".
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "off", "0"})


def str2bool(string):
    """Convert a command-line string to a boolean (used for ``--verbose``).

    Matching is case-insensitive and ignores surrounding whitespace.
    "false", "f", "no", "n", "off" and "0" map to ``False``. Any other
    value (e.g. "True", "T", "yes", "1") maps to ``True``, as before.

    Args:
        string: the raw option value supplied by argparse.

    Returns:
        bool: the interpreted flag value.
    """
    # Previously only "false"/"f" were recognised, so e.g. "--verbose 0"
    # silently kept verbose output on.
    return string.strip().lower() not in _FALSE_STRINGS
|
|
|
|
|
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
def setup_arg_parser():
    """Build the command-line parser for the LLM inference script.

    Options are registered in a fixed order, which is also the order in
    which argparse lists them in ``--help``.

    Returns:
        argparse.ArgumentParser: parser covering model selection, tokenizer
        options, prompt input, sampling parameters, output formatting and
        cache settings.
    """
    parser = argparse.ArgumentParser(description="LLM inference script")
    add = parser.add_argument

    # Model / tokenizer selection.
    model_help = (
        "The path to the local model directory or Hugging Face repo. "
        f"If no model is specified, then {DEFAULT_MODEL} is used."
    )
    add("--model", type=str, default=None, help=model_help)
    add(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    add(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    add(
        "--eos-token",
        type=str,
        default=None,
        help="End of sequence token for tokenizer",
    )

    # Prompt and sampling.
    add(
        "--prompt",
        default=DEFAULT_PROMPT,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    add(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    add("--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature")
    add("--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p")
    add("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")

    # Chat template handling.
    add(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    add(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )

    # Output formatting.
    add(
        "--verbose",
        type=str2bool,
        default=True,
        help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
    )
    add(
        "--colorize",
        action="store_true",
        help="Colorize output based on T[0] probability",
    )

    # Memory and caching.
    add(
        "--cache-limit-gb",
        type=int,
        default=None,
        help="Set the MLX cache limit in GB",
    )
    add(
        "--max-kv-size",
        type=int,
        default=None,
        help="Set the maximum key-value cache size",
    )
    add(
        "--prompt-cache-file",
        type=str,
        default=None,
        help="A file containing saved KV caches to avoid recomputing them",
    )
    return parser
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
|
2024-01-23 21:25:44 +08:00
|
|
|
def colorprint(color, s):
    """Write ``s`` to stdout in bold using an ANSI foreground colour.

    Args:
        color: one of "black", "red", "green", "yellow", "blue",
            "magenta", "cyan" (SGR codes 30-36) or "white". "white" maps
            to SGR 39, the terminal's default foreground. Unknown names
            fall back to 30 (black).
        s: text to print. No newline is appended and stdout is flushed
            immediately.
    """
    ordered = ("black", "red", "green", "yellow", "blue", "magenta", "cyan")
    palette = {name: 30 + offset for offset, name in enumerate(ordered)}
    palette["white"] = 39
    code = palette.get(color, 30)
    bold, reset = "\033[1m", "\033[0m"
    print(f"{bold}\033[{code}m{s}{reset}", end="", flush=True)
|
|
|
|
|
|
|
|
|
2024-01-24 04:44:23 +08:00
|
|
|
def colorprint_by_t0(s, t0):
    """Print ``s`` in a colour chosen from the probability ``t0``.

    Bands are strict lower bounds, checked from highest to lowest:
    ``t0 > 0.95`` -> white, ``> 0.70`` -> green, ``> 0.30`` -> yellow,
    and anything else (including values for which every comparison is
    false) -> red.
    """
    bands = ((0.95, "white"), (0.70, "green"), (0.30, "yellow"))
    chosen = next((name for floor, name in bands if t0 > floor), "red")
    colorprint(chosen, s)
|
2024-01-23 21:25:44 +08:00
|
|
|
|
|
|
|
|
2024-04-17 07:08:49 +08:00
|
|
|
def main():
    """Command-line entry point: parse flags, load the model and generate.

    Steps:
      1. Parse the arguments and reject incompatible flags up front.
      2. Seed the MLX PRNG and optionally set the Metal cache limit.
      3. If ``--prompt-cache-file`` is given, load the saved cache and its
         metadata. The metadata provides the tokenizer config (a JSON
         string), the model path and the chat template.
      4. Load the model/tokenizer and build the prompt. The chat template
         is applied unless disabled or unavailable.
      5. Call ``generate`` and print the response when not verbose.

    Raises:
        ValueError: if ``--colorize`` is combined with ``--verbose=False``,
            or if ``--model`` differs from the model recorded in the prompt
            cache metadata.
    """
    parser = setup_arg_parser()
    args = parser.parse_args()

    # Check flag compatibility before any expensive work such as model loading.
    if args.colorize and not args.verbose:
        raise ValueError("Cannot use --colorize with --verbose=False")

    mx.random.seed(args.seed)

    if args.cache_limit_gb is not None:
        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

    # Load the prompt cache and metadata if a cache file is provided
    using_cache = args.prompt_cache_file is not None
    prompt_cache = None
    metadata = {}
    if using_cache:
        prompt_cache, metadata = load_prompt_cache(
            args.prompt_cache_file, return_metadata=True
        )

    # Building tokenizer_config: start from the config saved with the cache
    # (stored as a JSON string), then apply command-line overrides.
    tokenizer_config = (
        json.loads(metadata["tokenizer_config"]) if using_cache else {}
    )
    if args.trust_remote_code:
        tokenizer_config["trust_remote_code"] = True
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    model_path = args.model
    if using_cache:
        if model_path is None:
            model_path = metadata["model"]
        elif model_path != metadata["model"]:
            raise ValueError(
                f"Providing a different model ({model_path}) than that "
                f"used to create the prompt cache ({metadata['model']}) "
                "is an error."
            )
    model_path = model_path or DEFAULT_MODEL

    model, tokenizer = load(
        model_path,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template
    elif using_cache:
        tokenizer.chat_template = metadata["chat_template"]

    # '-' means "read the prompt from stdin". Resolve it once so it is
    # honoured whether or not a chat template is applied; previously the
    # raw-prompt path passed the literal "-" to the model.
    user_prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    if not args.ignore_chat_template and (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": user_prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Treat the prompt as a suffix assuming that the prefix is in the
        # stored kv cache: render the template around a placeholder and
        # drop everything before the placeholder's position.
        if using_cache:
            test_prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": "<query>"}],
                tokenize=False,
                add_generation_prompt=True,
            )
            prompt = prompt[test_prompt.index("<query>") :]
    else:
        prompt = user_prompt

    formatter = colorprint_by_t0 if args.colorize else None

    response = generate(
        model,
        tokenizer,
        prompt,
        args.max_tokens,
        verbose=args.verbose,
        formatter=formatter,
        temp=args.temp,
        top_p=args.top_p,
        max_kv_size=args.max_kv_size,
        prompt_cache=prompt_cache,
    )
    # In verbose mode the response has presumably already been streamed by
    # generate(); otherwise print it once here.
    if not args.verbose:
        print(response)
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Run the CLI only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()
|