This commit is contained in:
Awni Hannun 2025-02-27 07:39:15 -08:00
parent 9f9da6af23
commit 109eb4e942

View File

@@ -224,21 +224,16 @@ def main():
messages = []
messages.append({"role": "user", "content": prompt})
if args.prefill_response is not None:
has_prefill = args.prefill_response is not None
if has_prefill:
messages.append({"role": "assistant", "content": args.prefill_response})
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
continue_final_message=True,
**template_kwargs,
)
else:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
**template_kwargs,
)
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
continue_final_message=has_prefill,
add_generation_prompt=not has_prefill,
**template_kwargs,
)
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
@@ -247,7 +242,8 @@ def main():
test_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
continue_final_message=has_prefill,
add_generation_prompt=not has_prefill,
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)