Generate: Support Prefill Response (#1299)

* Generate: Support Prefill Prompt

python -m mlx_lm.generate \
       --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit \
       --prompt "hello" \
       --prefill-prompt "<think>\n"

* Generate: rename prefill-prompt to prefill-response

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
madroid 2025-02-27 23:44:00 +08:00 committed by GitHub
parent 00a7379070
commit eb73549631

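The diff below wires the new --prefill-response flag into the chat template call: when a prefill is given, the text is appended as a trailing assistant turn and the template is rendered with continue_final_message=True instead of add_generation_prompt=True, so generation continues the prefilled text rather than opening a fresh assistant turn. A minimal sketch of that template behaviour, assuming a Hugging Face tokenizer whose apply_chat_template supports continue_final_message (recent transformers) and reusing the model from the example above; the exact rendered strings depend on the model's chat template:

from transformers import AutoTokenizer

# Tokenizer from the commit-message example; any chat-template tokenizer
# with continue_final_message support behaves analogously.
tokenizer = AutoTokenizer.from_pretrained(
    "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit"
)

messages = [{"role": "user", "content": "hello"}]

# Without a prefill: append an empty assistant header so generation
# starts a brand-new assistant turn.
no_prefill = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# With a prefill: seed the assistant turn (here with "<think>\n") and
# render it as an unfinished turn that generation will continue.
with_prefill = tokenizer.apply_chat_template(
    messages + [{"role": "assistant", "content": "<think>\n"}],
    tokenize=False,
    continue_final_message=True,
    add_generation_prompt=False,
)

print(repr(no_prefill))
print(repr(with_prefill))  # ends with the prefilled "<think>\n"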

@@ -60,6 +60,11 @@ def setup_arg_parser():
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--prefill-response",
+        default=None,
+        help="Prefill response to be used for the chat template",
+    )
     parser.add_argument(
         "--max-tokens",
         "-m",
@@ -219,10 +224,14 @@ def main():
             messages = []
         messages.append({"role": "user", "content": prompt})
+        has_prefill = args.prefill_response is not None
+        if has_prefill:
+            messages.append({"role": "assistant", "content": args.prefill_response})
         prompt = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True,
+            continue_final_message=has_prefill,
+            add_generation_prompt=not has_prefill,
             **template_kwargs,
         )
@@ -233,7 +242,8 @@ def main():
             test_prompt = tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
-                add_generation_prompt=True,
+                continue_final_message=has_prefill,
+                add_generation_prompt=not has_prefill,
             )
             prompt = prompt[test_prompt.index("<query>") :]
         prompt = tokenizer.encode(prompt, add_special_tokens=False)
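
For reference, here is a sketch of the same prefill flow through the Python API, assuming the standard mlx_lm load/generate helpers and a tokenizer whose apply_chat_template accepts continue_final_message; the model and the "<think>\n" prefill come from the example above:

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit")

prefill_response = "<think>\n"  # same prefill as the CLI example
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": prefill_response},
]

# Mirror the CLI logic above: continue the prefilled assistant turn
# instead of opening a new one.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    continue_final_message=True,
    add_generation_prompt=False,
)

text = generate(model, tokenizer, prompt=prompt, verbose=True)

The generated text then picks up immediately after the prefilled "<think>\n", which matches what the CLI does when --prefill-response is set.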