From 0593aaea89f56b1539cb3b75420033913729b37d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 23 Nov 2024 09:19:08 -0800
Subject: [PATCH] default trust remote code for tokenizer, allow system prompt
 to be configurable

---
 llms/mlx_lm/generate.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 51169def..de5c5719 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -41,17 +41,17 @@ def setup_arg_parser():
         type=str,
         help="Optional path for the trained adapter weights and config.",
     )
-    parser.add_argument(
-        "--trust-remote-code",
-        action="store_true",
-        help="Enable trusting remote code for tokenizer",
-    )
     parser.add_argument(
         "--eos-token",
         type=str,
         default=None,
         help="End of sequence token for tokenizer",
     )
+    parser.add_argument(
+        "--system-prompt",
+        default=None,
+        help="System prompt to be used for the chat template",
+    )
     parser.add_argument(
         "--prompt",
         "-p",
@@ -191,8 +191,7 @@ def main():
     tokenizer_config = (
         {} if not using_cache else json.loads(metadata["tokenizer_config"])
     )
-    if args.trust_remote_code:
-        tokenizer_config["trust_remote_code"] = True
+    tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
 
@@ -224,12 +223,16 @@ def main():
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
     ):
-        messages = [
+        if args.system_prompt is not None:
+            messages = [{"role": "system", "content": args.system_prompt}]
+        else:
+            messages = []
+        messages.append(
             {
                 "role": "user",
                 "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
             }
-        ]
+        )
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
@@ -237,8 +240,9 @@
         # Treat the prompt as a suffix assuming that the prefix is in the
         # stored kv cache.
         if using_cache:
+            messages[-1]["content"] = ""
             test_prompt = tokenizer.apply_chat_template(
-                [{"role": "user", "content": ""}],
+                messages,
                 tokenize=False,
                 add_generation_prompt=True,
             )
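
For context, here is a minimal standalone sketch of the message construction this patch introduces, separated from the CLI plumbing. The tokenizer and model name below are illustrative assumptions, not part of the patch; any Hugging Face tokenizer whose chat template accepts a "system" role should behave similarly:

```python
# Minimal sketch of the behavior the patch adds, assuming a Hugging Face
# tokenizer whose chat template accepts a "system" role. The model name
# is illustrative, not taken from the patch.
from transformers import AutoTokenizer

# trust_remote_code=True mirrors the new unconditional default in main().
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True
)

system_prompt = "You are a terse assistant."  # would come from --system-prompt
user_prompt = "Explain KV caching in one sentence."  # would come from --prompt

# Same construction as the patched main(): optional system message first,
# then the user message.
if system_prompt is not None:
    messages = [{"role": "system", "content": system_prompt}]
else:
    messages = []
messages.append({"role": "user", "content": user_prompt})

prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```

On the CLI, the equivalent after this patch would presumably be something like `python -m mlx_lm.generate --model <model-path> --system-prompt "You are a terse assistant." --prompt "..."`, with the model path supplied by the user.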