fix encoding with special tokens + chat template (#1189)

Awni Hannun
2025-01-03 10:50:59 -08:00
committed by GitHub
parent 3a58c36109
commit c4833a2f55
13 changed files with 95 additions and 97 deletions

@@ -190,10 +190,7 @@ def main():
     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         if args.system_prompt is not None:
             messages = [{"role": "system", "content": args.system_prompt}]
         else:
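
The dropped hasattr guard is redundant on any reasonably recent transformers release: PreTrainedTokenizerBase defines apply_chat_template whether or not a template is configured, so tokenizer.chat_template is not None is the only check that carries information. A minimal illustration of that, not part of this commit ("gpt2" is just an arbitrary model that ships without a chat template):

from transformers import AutoTokenizer

# hasattr() is always True on current transformers, even for models with no
# configured template, so the simplified condition keys off chat_template alone.
tok = AutoTokenizer.from_pretrained("gpt2")
print(hasattr(tok, "apply_chat_template"))  # True
print(tok.chat_template is not None)        # False -> the chat-template path is skipped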
@@ -214,6 +211,10 @@ def main():
             )
             prompt = prompt[test_prompt.index("<query>") :]
+        prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        prompt = tokenizer.encode(prompt)
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
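
The second hunk is the encoding fix named in the commit title: a string produced by apply_chat_template(..., tokenize=False) already contains the model's special tokens, so re-encoding it with the default add_special_tokens=True would prepend a second BOS token, while a prompt with no template applied still needs the default special tokens. A rough sketch of the two paths, assuming a Llama-style tokenizer whose template emits BOS (the model id below is a placeholder, not taken from this commit):

from transformers import AutoTokenizer

# Placeholder model; any tokenizer whose chat template inserts BOS behaves the same.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
templated = tok.apply_chat_template(
    [{"role": "user", "content": "hi"}],
    tokenize=False,
    add_generation_prompt=True,
)
# `templated` already starts with the BOS token inserted by the template.
doubled = tok.encode(templated)                          # default add_special_tokens=True -> duplicate BOS
fixed = tok.encode(templated, add_special_tokens=False)  # special tokens come only from the template
plain = tok.encode("a prompt with no template applied")  # no template, so keep the default special tokens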