Merge branch 'ml-explore:main' into adding-GRPO-training

This commit is contained in:
Gökdeniz Gülmez 2025-02-28 11:18:32 +01:00 committed by GitHub
commit a04eb02257
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -60,6 +60,11 @@ def setup_arg_parser():
default=DEFAULT_PROMPT, default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)", help="Message to be processed by the model ('-' reads from stdin)",
) )
parser.add_argument(
"--prefill-response",
default=None,
help="Prefill response to be used for the chat template",
)
parser.add_argument( parser.add_argument(
"--max-tokens", "--max-tokens",
"-m", "-m",
@@ -219,10 +224,14 @@ def main():
messages = [] messages = []
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
has_prefill = args.prefill_response is not None
if has_prefill:
messages.append({"role": "assistant", "content": args.prefill_response})
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=True, continue_final_message=has_prefill,
add_generation_prompt=not has_prefill,
**template_kwargs, **template_kwargs,
) )
@@ -233,7 +242,8 @@ def main():
test_prompt = tokenizer.apply_chat_template( test_prompt = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=True, continue_final_message=has_prefill,
add_generation_prompt=not has_prefill,
) )
prompt = prompt[test_prompt.index("<query>") :] prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False) prompt = tokenizer.encode(prompt, add_special_tokens=False)