diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index e7994750..d8f97e5e 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -93,6 +93,12 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--chat-template-config",
+        help="Additional config for `apply_chat_template`. Should be a dictionary of"
+        " string keys to values represented as a JSON decodable string.",
+        default=None,
+    )
     parser.add_argument(
         "--verbose",
         type=str2bool,
@@ -149,7 +155,6 @@ def setup_arg_parser():
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
-
     mx.random.seed(args.seed)

     # Load the prompt cache and metadata if a cache file is provided
@@ -195,6 +200,10 @@ def main():
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)

+    template_kwargs = {}
+    if args.chat_template_config is not None:
+        template_kwargs = json.loads(args.chat_template_config)
+
    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template
@@ -209,8 +218,12 @@ def main():
         else:
             messages = []
         messages.append({"role": "user", "content": prompt})
+
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            **template_kwargs,
         )

         # Treat the prompt as a suffix assuming that the prefix is in the
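
Usage sketch: the new `--chat-template-config` flag takes a JSON object whose keys are forwarded as keyword arguments to `apply_chat_template`. The model path and the `enable_thinking` key below are illustrative assumptions; any kwargs the loaded model's chat template actually accepts can be passed this way.

    python -m mlx_lm.generate --model <model-path> --prompt "hello" \
        --chat-template-config '{"enable_thinking": false}'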