diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 530a3483..31c06eb4 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -46,6 +46,11 @@ def setup_arg_parser():
         "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
     )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+    parser.add_argument(
+        "--ignore-chat-template",
+        action="store_true",
+        help="Use the raw prompt without the tokenizer's chat template.",
+    )
     return parser


@@ -58,9 +63,21 @@ def main(args):
         tokenizer_config["eos_token"] = args.eos_token
     model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
+
+    if not args.ignore_chat_template and (
+        hasattr(tokenizer, "apply_chat_template")
+        and tokenizer.chat_template is not None
+    ):
+        messages = [{"role": "user", "content": args.prompt}]
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    else:
+        prompt = args.prompt
+
     print("=" * 10)
-    print("Prompt:", args.prompt)
-    prompt = tokenizer.encode(args.prompt)
+    print("Prompt:", prompt)
+    prompt = tokenizer.encode(prompt)
     prompt = mx.array(prompt)
     tic = time.time()
     tokens = []
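
For context on what the new default path produces, here is a minimal sketch of the chat-template behavior, assuming the Hugging Face transformers tokenizer API; the model name "mistralai/Mistral-7B-Instruct-v0.2" is only an illustrative choice and is not part of this patch:

    from transformers import AutoTokenizer

    # Any model whose tokenizer ships a chat template works here; this name is
    # just an example, not something the diff above depends on.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    messages = [{"role": "user", "content": "Write a haiku about the ocean."}]

    # Default path in the diff above: wrap the user prompt in the model's chat format.
    templated = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(templated)  # e.g. "<s>[INST] Write a haiku about the ocean. [/INST]"

    # With --ignore-chat-template (or when no chat template exists), the raw
    # prompt string is what gets passed to tokenizer.encode instead.
    print(messages[0]["content"])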