Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-19 17:41:11 +08:00)
Add prompt piping (#962)
* Initial commit of --prompt-only and prompt from STDIN feature
* Switch to using --verbose instead of --prompt-only
* Fix capitalization typo
* Fix reference to changed option name
* Update exception text
This commit is contained in:
parent bf921afcbe
commit 83a209e200
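The headline change lets a prompt be piped in on stdin and the bare response be piped back out. As a hypothetical end-to-end check of that workflow (assuming mlx_lm is installed and a default model is available; the module path and flags match the diff below):

import subprocess

# Feed the prompt on stdin via the new "-" sentinel and disable the
# verbose logging so only the completion reaches stdout.
result = subprocess.run(
    ["python", "-m", "mlx_lm.generate", "--prompt", "-", "--verbose", "False"],
    input="Write a haiku about GPUs.",
    capture_output=True,
    text=True,
)
print(result.stdout)  # the bare response, suitable for further piping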
@@ -2,6 +2,7 @@
 
 import argparse
 import json
+import sys
 
 import mlx.core as mx
 
@@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
 
 
+def str2bool(string):
+    return string.lower() not in ["false", "f"]
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="LLM inference script")
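For reference, the new str2bool helper treats any value other than a case-insensitive "false" or "f" as True, including the empty string. A minimal standalone sketch of how it behaves under argparse:

import argparse

def str2bool(string):
    # As added above: only "false"/"f" (case-insensitive) parse as
    # False; every other string, including "", parses as True.
    return string.lower() not in ["false", "f"]

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", type=str2bool, default=True)

print(parser.parse_args(["--verbose", "False"]).verbose)  # False
print(parser.parse_args(["--verbose", "f"]).verbose)      # False
print(parser.parse_args(["--verbose", "T"]).verbose)      # True
print(parser.parse_args([]).verbose)                      # True (default)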
@@ -39,7 +44,9 @@ def setup_arg_parser()
         help="End of sequence token for tokenizer",
     )
     parser.add_argument(
-        "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
+        "--prompt",
+        default=DEFAULT_PROMPT,
+        help="Message to be processed by the model ('-' reads from stdin)",
     )
     parser.add_argument(
         "--max-tokens",
@@ -65,6 +72,12 @@ def setup_arg_parser()
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--verbose",
+        type=str2bool,
+        default=True,
+        help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
+    )
     parser.add_argument(
         "--colorize",
         action="store_true",
@@ -178,7 +191,12 @@ def main()
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
     ):
-        messages = [{"role": "user", "content": args.prompt}]
+        messages = [
+            {
+                "role": "user",
+                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+            }
+        ]
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
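The "-" sentinel follows the common Unix convention of reading from stdin. A minimal sketch of the fallback logic in isolation (demo.py is a hypothetical file name; the parser mirrors the flags above):

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument(
    "--prompt", default="hello", help="Message to process ('-' reads from stdin)"
)
args = parser.parse_args()

# Read all of stdin only when the sentinel "-" is passed, exactly as
# the hunk above does; otherwise the flag value is used directly.
prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
print(prompt)

Invoked as echo "Why is the sky blue?" | python demo.py --prompt -, this prints the piped text rather than the default.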
@@ -195,6 +213,8 @@ def main()
     else:
         prompt = args.prompt
 
+    if args.colorize and not args.verbose:
+        raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None
 
     # Determine the max kv size from the kv cache or passed arguments
@@ -203,18 +223,20 @@ def main()
         max_kv_size = metadata["max_kv_size"]
         max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
 
-    generate(
+    response = generate(
         model,
         tokenizer,
         prompt,
         args.max_tokens,
-        verbose=True,
+        verbose=args.verbose,
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
         max_kv_size=max_kv_size,
         cache_history=cache_history,
     )
+    if not args.verbose:
+        print(response)
 
 
 if __name__ == "__main__":
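With the return value captured, quiet mode prints the completion exactly once. A sketch of that control flow, where generate_stub is a hypothetical stand-in for mlx_lm's generate (which streams output itself when verbose and also returns the full text):

def generate_stub(prompt, verbose=True):
    # Hypothetical stand-in: streams output when verbose and always
    # returns the completion, mirroring how main() uses generate().
    response = f"echo: {prompt}"
    if verbose:
        print(response)
    return response

verbose = False
response = generate_stub("hello", verbose=verbose)
if not verbose:
    print(response)  # quiet mode: emit only the bare completion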