This commit is contained in:
Awni Hannun 2025-02-27 07:39:15 -08:00
parent 9f9da6af23
commit 109eb4e942

View File

@@ -224,19 +224,14 @@ def main():
         messages = []
         messages.append({"role": "user", "content": prompt})
-        if args.prefill_response is not None:
+        has_prefill = args.prefill_response is not None
+        if has_prefill:
             messages.append({"role": "assistant", "content": args.prefill_response})
-            prompt = tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                continue_final_message=True,
-                **template_kwargs,
-            )
-        else:
-            prompt = tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                **template_kwargs,
-            )
+        prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            continue_final_message=has_prefill,
+            add_generation_prompt=not has_prefill,
+            **template_kwargs,
+        )
@@ -247,7 +242,8 @@ def main():
             test_prompt = tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
-                add_generation_prompt=True,
+                continue_final_message=has_prefill,
+                add_generation_prompt=not has_prefill,
             )
             prompt = prompt[test_prompt.index("<query>") :]
             prompt = tokenizer.encode(prompt, add_special_tokens=False)