diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 0523be50..97a9b40c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
-        stop_words = self.body.get("stop", [])
+        stop_words = self.body.get("stop")
+        stop_words = stop_words or []
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)
@@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler):
         if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
             raise ValueError("max_tokens must be a non-negative integer")
 
-        if not isinstance(self.temperature, float) or self.temperature < 0:
+        if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
             raise ValueError("temperature must be a non-negative float")
 
-        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+        if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
             raise ValueError("top_p must be a float between 0 and 1")
 
         if (
-            not isinstance(self.repetition_penalty, float)
+            not isinstance(self.repetition_penalty, (float, int))
             or self.repetition_penalty < 0
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
@@ -527,6 +528,18 @@ def main():
         help="Set the MLX cache limit in GB",
         required=False,
     )
+    parser.add_argument(
+        "--chat-template",
+        type=str,
+        default="",
+        help="Specify a chat template for the tokenizer",
+        required=False,
+    )
+    parser.add_argument(
+        "--use-default-chat-template",
+        action="store_true",
+        help="Use the default chat template",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -540,10 +553,17 @@ def main():
 
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.chat_template:
+        tokenizer_config["chat_template"] = args.chat_template
 
     model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
+
+    if args.use_default_chat_template:
+        if tokenizer.chat_template is None:
+            tokenizer.chat_template = tokenizer.default_chat_template
+
     run(args.host, args.port, model, tokenizer)
 
 
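
For context, a small standalone sketch (not part of the patch; the `body` dict is a hand-built stand-in for a parsed request payload) of the two request-parsing behaviors changed by the first two hunks:

# Standalone illustration of the request-parsing fixes above; `body` stands in
# for a parsed JSON request body, not actual server.py code.

body = {"stop": None, "temperature": 1}  # e.g. a client that sends "stop": null

# Stop words: dict.get("stop", []) only falls back when the key is missing,
# so an explicit null used to come through as None. Normalizing with `or []`
# keeps the later isinstance/encode steps safe.
stop_words = body.get("stop")
stop_words = stop_words or []
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
assert stop_words == []

# Sampling parameters: JSON numbers like 1 or 0 arrive as Python ints, which a
# strict isinstance(..., float) check rejects; accepting (float, int) keeps
# such requests valid while still enforcing the range.
temperature = body["temperature"]
assert isinstance(temperature, (float, int)) and temperature >= 0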