add an option to apply the tokenizer chat template (#338)

* add an option to apply the tokenizer chat template

* fix the option to apply the tokenizer chat template

* better error messages for chat template issues

* apply the chat template by default when possible

* nit in comment'

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Baptiste Canton 2024-01-23 04:52:42 +01:00 committed by GitHub
parent 8022083979
commit 42672f5446
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,11 @@ def setup_arg_parser():
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
return parser
@ -58,9 +63,21 @@ def main(args):
tokenizer_config["eos_token"] = args.eos_token
model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
prompt = args.prompt
print("=" * 10)
print("Prompt:", args.prompt)
prompt = tokenizer.encode(args.prompt)
print("Prompt:", prompt)
prompt = tokenizer.encode(prompt)
prompt = mx.array(prompt)
tic = time.time()
tokens = []