# Copyright © 2023-2024 Apple Inc.

import argparse
import json
import sys

import mlx.core as mx

from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load

DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = None
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
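
# Example invocation (assumes this script is exposed as the `mlx_lm.generate`
# module; adjust the module path to match your installation):
#
#   python -m mlx_lm.generate --model mlx-community/Llama-3.2-3B-Instruct-4bit \
#       --prompt "Write a haiku about autumn" --max-tokens 128 --temp 0.7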


def str2bool(string):
    return string.lower() not in ["false", "f"]


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="LLM inference script")
    parser.add_argument(
        "--model",
        type=str,
        help=(
            "The path to the local model directory or Hugging Face repo. "
            f"If no model is specified, then {DEFAULT_MODEL} is used."
        ),
        default=None,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--extra-eos-token",
        type=str,
        default=(),
        nargs="+",
        help="Add tokens to the list of eos tokens that stop generation.",
    )
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="System prompt to be used for the chat template",
    )
    parser.add_argument(
        "--prompt",
        "-p",
        default=DEFAULT_PROMPT,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--prefill-response",
        default=None,
        help="Prefill response to be used for the chat template",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument(
        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
    )
    parser.add_argument(
        "--min-tokens-to-keep",
        type=int,
        default=DEFAULT_MIN_TOKENS_TO_KEEP,
        help="Minimum tokens to keep for min-p sampling.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help="PRNG seed",
    )
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    parser.add_argument(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )
    parser.add_argument(
        "--chat-template-config",
        help="Additional config for `apply_chat_template`. Should be a dictionary of"
        " string keys to values represented as a JSON decodable string.",
        default=None,
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=True,
        help="Log verbose output when 'True' or 'T', or only print the response when 'False' or 'F'",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    parser.add_argument(
        "--prompt-cache-file",
        type=str,
        default=None,
        help="A file containing saved KV caches to avoid recomputing them",
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        help="Number of bits for KV cache quantization. "
        "Defaults to no quantization.",
        default=None,
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        help="Group size for KV cache quantization.",
        default=64,
    )
    parser.add_argument(
        "--quantized-kv-start",
        help="When --kv-bits is set, start quantizing the KV cache "
        "from this step onwards.",
        type=int,
        default=DEFAULT_QUANTIZED_KV_START,
    )
    parser.add_argument(
        "--draft-model",
        type=str,
        help="A model to be used for speculative decoding.",
        default=None,
    )
    parser.add_argument(
        "--num-draft-tokens",
        type=int,
        help="Number of tokens to draft when using speculative decoding.",
        default=3,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.seed is not None:
        mx.random.seed(args.seed)

    # Load the prompt cache and metadata if a cache file is provided
    using_cache = args.prompt_cache_file is not None
    if using_cache:
        prompt_cache, metadata = load_prompt_cache(
            args.prompt_cache_file,
            return_metadata=True,
        )
        if isinstance(prompt_cache[0], QuantizedKVCache):
            if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
                raise ValueError(
                    "--kv-bits does not match the kv cache loaded from --prompt-cache-file."
                )
            if args.kv_group_size != prompt_cache[0].group_size:
                raise ValueError(
                    "--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
                )

    # Building tokenizer_config
    tokenizer_config = (
        {} if not using_cache else json.loads(metadata["tokenizer_config"])
    )
    tokenizer_config["trust_remote_code"] = True

    model_path = args.model
    if using_cache:
        if model_path is None:
            model_path = metadata["model"]
        elif model_path != metadata["model"]:
            raise ValueError(
                f"Providing a different model ({model_path}) than that "
                f"used to create the prompt cache ({metadata['model']}) "
                "is an error."
            )
    model_path = model_path or DEFAULT_MODEL
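
    # Load the model and tokenizer, applying any adapter weights and the
    # tokenizer configuration built above.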
    model, tokenizer = load(
        model_path,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )
    for eos_token in args.extra_eos_token:
        tokenizer.add_eos_token(eos_token)

    template_kwargs = {}
    if args.chat_template_config is not None:
        template_kwargs = json.loads(args.chat_template_config)

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template
    elif using_cache:
        tokenizer.chat_template = json.loads(metadata["chat_template"])
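
    # Resolve the prompt text: expand escaped newlines/tabs and read from
    # stdin when the prompt is "-".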
    prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
    prompt = sys.stdin.read() if prompt == "-" else prompt
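    # When a chat template is available and not explicitly ignored, wrap the
    # prompt in chat messages (optionally adding a system prompt and a
    # prefilled assistant response) and render it with the tokenizer's
    # template before encoding.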
    if not args.ignore_chat_template and tokenizer.chat_template is not None:
        if args.system_prompt is not None:
            messages = [{"role": "system", "content": args.system_prompt}]
        else:
            messages = []
        messages.append({"role": "user", "content": prompt})

        has_prefill = args.prefill_response is not None
        if has_prefill:
            messages.append({"role": "assistant", "content": args.prefill_response})
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            continue_final_message=has_prefill,
            add_generation_prompt=not has_prefill,
            **template_kwargs,
        )

        # Treat the prompt as a suffix assuming that the prefix is in the
        # stored kv cache.
        if using_cache:
            messages[-1]["content"] = "<query>"
            test_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                continue_final_message=has_prefill,
                add_generation_prompt=not has_prefill,
            )
            prompt = prompt[test_prompt.index("<query>") :]
        prompt = tokenizer.encode(prompt, add_special_tokens=False)
    else:
        prompt = tokenizer.encode(prompt)
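
    # Optionally load a draft model for speculative decoding and check that
    # its vocabulary matches the main model's tokenizer.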
    if args.draft_model is not None:
        draft_model, draft_tokenizer = load(args.draft_model)
        if draft_tokenizer.vocab_size != tokenizer.vocab_size:
            raise ValueError("Draft model tokenizer does not match model tokenizer.")
    else:
        draft_model = None
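    # Build the token sampler from the sampling flags and run generation.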
    sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
    response = generate(
        model,
        tokenizer,
        prompt,
        max_tokens=args.max_tokens,
        verbose=args.verbose,
        sampler=sampler,
        max_kv_size=args.max_kv_size,
        prompt_cache=prompt_cache if using_cache else None,
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
        draft_model=draft_model,
        num_draft_tokens=args.num_draft_tokens,
    )
    if not args.verbose:
        print(response)


if __name__ == "__main__":
    main()