Add prompt piping (#962)

* Initial commit of --prompt-only and prompt from STDIN feature

* Switch to using --verbose instead of --prompt-only

* Fix capitalization typo

* Fix reference to changed option name

* Update exception text
Author: Chime Ogbuji
Date: 2024-09-03 16:29:10 -04:00 (committed by GitHub)
Commit: 83a209e200 (parent bf921afcbe)

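Taken together, the changes below (to mlx_lm's generate.py) let `--prompt -` read the prompt from standard input and let `--verbose False` print only the model's response, so the CLI composes in shell pipelines. A minimal usage sketch, assuming the mlx-lm package is installed; the model path and prompt text are placeholders, not part of this commit:

import subprocess

# Pipe a prompt in via stdin ("--prompt -") and capture only the
# response text ("--verbose False"). The model path is a placeholder.
result = subprocess.run(
    [
        "python", "-m", "mlx_lm.generate",
        "--model", "some/model-path",
        "--prompt", "-",
        "--verbose", "False",
    ],
    input="Write a haiku about autumn.",
    capture_output=True,
    text=True,
)
print(result.stdout)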

@@ -2,6 +2,7 @@
 
 import argparse
 import json
+import sys
 
 import mlx.core as mx
@@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
 
 
+def str2bool(string):
+    return string.lower() not in ["false", "f"]
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="LLM inference script")
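The parsing rule above is worth spelling out: any string other than "false"/"f" (case-insensitive) maps to True. A standalone restatement of the helper with a few checks (illustration only, not part of the diff):

def str2bool(string):
    return string.lower() not in ["false", "f"]

# Only "false"/"f" (any case) yield False; every other string is True.
assert str2bool("False") is False
assert str2bool("f") is False
assert str2bool("True") is True
assert str2bool("no") is True  # perhaps surprising, but consistent with the rule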
@@ -39,7 +44,9 @@ def setup_arg_parser():
         help="End of sequence token for tokenizer",
     )
     parser.add_argument(
-        "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
+        "--prompt",
+        default=DEFAULT_PROMPT,
+        help="Message to be processed by the model ('-' reads from stdin)",
     )
     parser.add_argument(
         "--max-tokens",
@@ -65,6 +72,12 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--verbose",
+        type=str2bool,
+        default=True,
+        help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
+    )
     parser.add_argument(
         "--colorize",
         action="store_true",
@@ -178,7 +191,12 @@ def main():
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
     ):
-        messages = [{"role": "user", "content": args.prompt}]
+        messages = [
+            {
+                "role": "user",
+                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+            }
+        ]
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
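The "-" convention mirrors common Unix tools such as `cat -`. The same pattern in isolation, sketched with a hypothetical helper name:

import sys

def resolve_prompt(prompt_arg):
    # "-" means: consume the entire prompt from standard input;
    # any other value is used literally.
    return sys.stdin.read() if prompt_arg == "-" else prompt_arg

Reading stdin eagerly is fine here, since the prompt is needed exactly once, before generation starts.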
@@ -195,6 +213,8 @@ def main():
     else:
         prompt = args.prompt
 
+    if args.colorize and not args.verbose:
+        raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None
 
     # Determine the max kv size from the kv cache or passed arguments
@@ -203,18 +223,20 @@
         max_kv_size = metadata["max_kv_size"]
         max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
 
-    generate(
+    response = generate(
         model,
         tokenizer,
         prompt,
         args.max_tokens,
-        verbose=True,
+        verbose=args.verbose,
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
         max_kv_size=max_kv_size,
         cache_history=cache_history,
     )
+    if not args.verbose:
+        print(response)
 
 
 if __name__ == "__main__":