default trust remote code for tokenizer, allow system prompt to be configurable

This commit is contained in:
Awni Hannun 2024-11-23 09:19:08 -08:00
parent 53569da120
commit 0593aaea89

View File

@@ -41,17 +41,17 @@ def setup_arg_parser():
         type=str,
         help="Optional path for the trained adapter weights and config.",
     )
-    parser.add_argument(
-        "--trust-remote-code",
-        action="store_true",
-        help="Enable trusting remote code for tokenizer",
-    )
     parser.add_argument(
         "--eos-token",
         type=str,
        default=None,
         help="End of sequence token for tokenizer",
     )
+    parser.add_argument(
+        "--system-prompt",
+        default=None,
+        help="System prompt to be used for the chat template",
+    )
     parser.add_argument(
         "--prompt",
         "-p",
@@ -191,8 +191,7 @@ def main():
     tokenizer_config = (
         {} if not using_cache else json.loads(metadata["tokenizer_config"])
     )
-    if args.trust_remote_code:
-        tokenizer_config["trust_remote_code"] = True
+    tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
@@ -224,12 +223,16 @@ def main():
             hasattr(tokenizer, "apply_chat_template")
             and tokenizer.chat_template is not None
         ):
-            messages = [
+            if args.system_prompt is not None:
+                messages = [{"role": "system", "content": args.system_prompt}]
+            else:
+                messages = []
+            messages.append(
                 {
                     "role": "user",
                     "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
                 }
-            ]
+            )
             prompt = tokenizer.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
@@ -237,8 +240,9 @@ def main():
             # Treat the prompt as a suffix assuming that the prefix is in the
             # stored kv cache.
             if using_cache:
+                messages[-1]["content"] = "<query>"
                 test_prompt = tokenizer.apply_chat_template(
-                    [{"role": "user", "content": "<query>"}],
+                    messages,
                     tokenize=False,
                     add_generation_prompt=True,
                 )