mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 19:18:09 +08:00 
			
		
		
		
	Add prompt piping (#962)
* Initial commit of --prompt-only and prompt-from-STDIN feature
* Switch to using --verbose instead of --prompt-only
* Fix capitalization typo
* Fix reference to changed option name
* Update exception text
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| import argparse | ||||
| import json | ||||
| import sys | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| @@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0 | ||||
| DEFAULT_SEED = 0 | ||||
|  | ||||
|  | ||||
def str2bool(string):
    """Parse a CLI flag value as a boolean.

    Only the strings "false" and "f" (case-insensitive) are treated as
    False; every other value — including "0", "no", or the empty
    string — is treated as True, per the --verbose help text.
    """
    normalized = string.lower()
    return normalized != "false" and normalized != "f"
|  | ||||
|  | ||||
| def setup_arg_parser(): | ||||
|     """Set up and return the argument parser.""" | ||||
|     parser = argparse.ArgumentParser(description="LLM inference script") | ||||
| @@ -39,7 +44,9 @@ def setup_arg_parser(): | ||||
|         help="End of sequence token for tokenizer", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" | ||||
|         "--prompt", | ||||
|         default=DEFAULT_PROMPT, | ||||
|         help="Message to be processed by the model ('-' reads from stdin)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max-tokens", | ||||
| @@ -65,6 +72,12 @@ def setup_arg_parser(): | ||||
|         action="store_true", | ||||
|         help="Use the default chat template", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--verbose", | ||||
|         type=str2bool, | ||||
|         default=True, | ||||
|         help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--colorize", | ||||
|         action="store_true", | ||||
| @@ -178,7 +191,12 @@ def main(): | ||||
|         hasattr(tokenizer, "apply_chat_template") | ||||
|         and tokenizer.chat_template is not None | ||||
|     ): | ||||
|         messages = [{"role": "user", "content": args.prompt}] | ||||
|         messages = [ | ||||
|             { | ||||
|                 "role": "user", | ||||
|                 "content": sys.stdin.read() if args.prompt == "-" else args.prompt, | ||||
|             } | ||||
|         ] | ||||
|         prompt = tokenizer.apply_chat_template( | ||||
|             messages, tokenize=False, add_generation_prompt=True | ||||
|         ) | ||||
| @@ -195,6 +213,8 @@ def main(): | ||||
|     else: | ||||
|         prompt = args.prompt | ||||
|  | ||||
|     if args.colorize and not args.verbose: | ||||
|         raise ValueError("Cannot use --colorize with --verbose=False") | ||||
|     formatter = colorprint_by_t0 if args.colorize else None | ||||
|  | ||||
|     # Determine the max kv size from the kv cache or passed arguments | ||||
| @@ -203,18 +223,20 @@ def main(): | ||||
|         max_kv_size = metadata["max_kv_size"] | ||||
|         max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None | ||||
|  | ||||
|     generate( | ||||
|     response = generate( | ||||
|         model, | ||||
|         tokenizer, | ||||
|         prompt, | ||||
|         args.max_tokens, | ||||
|         verbose=True, | ||||
|         verbose=args.verbose, | ||||
|         formatter=formatter, | ||||
|         temp=args.temp, | ||||
|         top_p=args.top_p, | ||||
|         max_kv_size=max_kv_size, | ||||
|         cache_history=cache_history, | ||||
|     ) | ||||
|     if not args.verbose: | ||||
|         print(response) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Chime Ogbuji
					Chime Ogbuji