From 83a209e200eef505a688997f64afdff2c6e0d363 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Tue, 3 Sep 2024 16:29:10 -0400 Subject: [PATCH] Add prompt piping (#962) * Initial commit of --prompt-only and prompt from STDIN feature * Switch to using --verbose instead of --prompt-only * Fix capitalization typo * Fix reference to changed option name * Update exception text --- llms/mlx_lm/generate.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index f37037b6..537bd853 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -2,6 +2,7 @@ import argparse import json +import sys import mlx.core as mx @@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +def str2bool(string): + return string.lower() not in ["false", "f"] + + def setup_arg_parser(): """Set up and return the argument parser.""" parser = argparse.ArgumentParser(description="LLM inference script") @@ -39,7 +44,9 @@ def setup_arg_parser(): help="End of sequence token for tokenizer", ) parser.add_argument( - "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" + "--prompt", + default=DEFAULT_PROMPT, + help="Message to be processed by the model ('-' reads from stdin)", ) parser.add_argument( "--max-tokens", @@ -65,6 +72,12 @@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) + parser.add_argument( + "--verbose", + type=str2bool, + default=True, + help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", + ) parser.add_argument( "--colorize", action="store_true", @@ -178,7 +191,12 @@ def main(): hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None ): - messages = [{"role": "user", "content": args.prompt}] + messages = [ + { + "role": "user", + "content": sys.stdin.read() if args.prompt == "-" else args.prompt, + } + ] prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @@ -195,6 +213,8 @@ def main(): else: prompt = args.prompt + if args.colorize and not args.verbose: + raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None # Determine the max kv size from the kv cache or passed arguments @@ -203,18 +223,20 @@ def main(): max_kv_size = metadata["max_kv_size"] max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None - generate( + response = generate( model, tokenizer, prompt, args.max_tokens, - verbose=True, + verbose=args.verbose, formatter=formatter, temp=args.temp, top_p=args.top_p, max_kv_size=max_kv_size, cache_history=cache_history, ) + if not args.verbose: + print(response) if __name__ == "__main__":