diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index d8f97e5e..a3e19b01 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -60,6 +60,11 @@ def setup_arg_parser():
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--prefill-prompt",
+        default=None,
+        help="Prefill prompt to be used for the chat template",
+    )
     parser.add_argument(
         "--max-tokens",
         "-m",
@@ -219,12 +224,21 @@ def main():
             messages = []
         messages.append({"role": "user", "content": prompt})

-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-            **template_kwargs,
-        )
+        if args.prefill_prompt is not None:
+            messages.append({"role": "assistant", "content": args.prefill_prompt})
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                continue_final_message=True,
+                **template_kwargs,
+            )
+        else:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                **template_kwargs,
+            )

         # Treat the prompt as a suffix assuming that the prefix is in the
         # stored kv cache.
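
For context, a minimal sketch (not part of this diff) of what the new prefill branch produces: with `continue_final_message=True`, the chat template leaves the final assistant turn open, so generation continues the prefill text rather than starting a fresh turn. This assumes a `transformers` release whose `apply_chat_template` supports `continue_final_message`; the model id below is only a placeholder example.

```python
# Minimal sketch (not part of this diff): how the prefill branch renders the
# prompt. Requires a transformers version whose apply_chat_template supports
# continue_final_message; the model id is a placeholder example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

messages = [
    {"role": "user", "content": "Give me three prime numbers."},
    # The --prefill-prompt text becomes a partial assistant message.
    {"role": "assistant", "content": "Sure, here are three primes:"},
]

# continue_final_message=True keeps the assistant turn open (no end-of-turn
# token is appended), so generation continues the prefill text.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    continue_final_message=True,
)
print(prompt)
```

With the flag wired up as in the diff, passing `--prefill-prompt` to `mlx_lm.generate` appends the prefill as an assistant message and renders the prompt this way before generation.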