From 109eb4e9427f3784dd7620071b754dceb6fa88cc Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 27 Feb 2025 07:39:15 -0800
Subject: [PATCH] nits

---
 llms/mlx_lm/generate.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 69a5d975..e40332dd 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -224,21 +224,16 @@ def main():
             messages = []
         messages.append({"role": "user", "content": prompt})
 
-        if args.prefill_response is not None:
+        has_prefill = args.prefill_response is not None
+        if has_prefill:
             messages.append({"role": "assistant", "content": args.prefill_response})
-            prompt = tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                continue_final_message=True,
-                **template_kwargs,
-            )
-        else:
-            prompt = tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                **template_kwargs,
-            )
+        prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            continue_final_message=has_prefill,
+            add_generation_prompt=not has_prefill,
+            **template_kwargs,
+        )
 
         # Treat the prompt as a suffix assuming that the prefix is in the
         # stored kv cache.
@@ -247,7 +242,8 @@
             test_prompt = tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
-                add_generation_prompt=True,
+                continue_final_message=has_prefill,
+                add_generation_prompt=not has_prefill,
             )
             prompt = prompt[test_prompt.index("<query>") :]
         prompt = tokenizer.encode(prompt, add_special_tokens=False)
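
For reference, the patch folds the two apply_chat_template branches into a single call by toggling continue_final_message and add_generation_prompt off one has_prefill flag. Below is a minimal standalone sketch of that behavior, assuming a Hugging Face tokenizer with a chat template and transformers new enough to support continue_final_message; the model name, prompt, and prefill string are illustrative placeholders, not taken from the patch.

# Sketch: render a chat prompt with or without an assistant prefill.
# The model name is only an example; any instruct model with a chat
# template should behave similarly.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

def render(prompt: str, prefill_response: str | None = None) -> str:
    messages = [{"role": "user", "content": prompt}]
    has_prefill = prefill_response is not None
    if has_prefill:
        messages.append({"role": "assistant", "content": prefill_response})
    # Exactly one of the two flags is True: either continue the prefilled
    # assistant turn, or open a fresh assistant turn for generation.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        continue_final_message=has_prefill,
        add_generation_prompt=not has_prefill,
    )

print(render("Write a haiku about the sea."))
print(render("Write a haiku about the sea.", prefill_response="Sure, here is a haiku:"))

With no prefill the rendered text ends with the template's generation prompt; with a prefill it ends mid-assistant-turn so generation continues the given response, which is exactly the distinction the consolidated call preserves.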