diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index d8f97e5e..e40332dd 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -60,6 +60,11 @@ def setup_arg_parser():
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--prefill-response",
+        default=None,
+        help="Prefill response to be used for the chat template",
+    )
     parser.add_argument(
         "--max-tokens",
         "-m",
@@ -219,10 +224,14 @@ def main():
         else:
             messages = []
         messages.append({"role": "user", "content": prompt})
+        has_prefill = args.prefill_response is not None
+        if has_prefill:
+            messages.append({"role": "assistant", "content": args.prefill_response})
         prompt = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True,
+            continue_final_message=has_prefill,
+            add_generation_prompt=not has_prefill,
             **template_kwargs,
         )
 
@@ -233,7 +242,8 @@ def main():
             test_prompt = tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
-                add_generation_prompt=True,
+                continue_final_message=has_prefill,
+                add_generation_prompt=not has_prefill,
             )
             prompt = prompt[test_prompt.index("<query>") :]
         prompt = tokenizer.encode(prompt, add_special_tokens=False)
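
For context, a minimal sketch (not part of the diff) of what the two template modes render, assuming a Hugging Face tokenizer whose apply_chat_template supports continue_final_message (present in recent transformers releases); the model id and prompt text are illustrative only:

# Minimal sketch, not part of the diff: compares the default generation-prompt
# rendering with the prefill rendering used when --prefill-response is given.
# Assumes a transformers version whose apply_chat_template supports
# continue_final_message; the model id is illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [{"role": "user", "content": "Name a color."}]

# Default path (no --prefill-response): open a fresh assistant turn.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

# Prefill path: append a partial assistant message and render without closing
# it, so generation continues from the prefilled text.
messages.append({"role": "assistant", "content": "Sure, my favorite color is"})
print(
    tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        continue_final_message=True,
        add_generation_prompt=False,
    )
)

With the patch applied, the same behavior would be exercised from the CLI along the lines of: python -m mlx_lm.generate --prompt "Name a color." --prefill-response "Sure, my favorite color is" (remaining arguments omitted).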