From 3cc58e17fbf21d3997061c3b63994275e4dcd454 Mon Sep 17 00:00:00 2001
From: Nada Amin
Date: Wed, 12 Jun 2024 10:17:06 -0400
Subject: [PATCH] Tweaks to run dspy-produced calls to the server, with gemma
 template. (#810)

* Tweaks to run dspy-produced calls to the server, with gemma template.

Following comment https://github.com/stanfordnlp/dspy/issues/385#issuecomment-1998939936

Can try it out with:

```sh
python -m server --model mlx-community/gemma-1.1-7b-it-4bit --port 1143
```

modulo patching the relative imports in server.py:

```
-from .tokenizer_utils import TokenizerWrapper
-from .utils import generate_step, load
+from mlx_lm.tokenizer_utils import TokenizerWrapper
+from mlx_lm.utils import generate_step, load
```

and then, on the dspy side:

```python
import dspy
lm = dspy.OpenAI(model_type="chat", api_base="http://localhost:11434/v1/", api_key="not_needed", max_tokens=250)
lm("hello")
```

* simpler way to validate float or int

* remove logic that works around incompatible templates, too gemma specific

* tweak messages for common denominator

* use generate.py workaround for DBXR

* put behind flag

* oops

* Solution to chat template issue: pass in a custom template!

The template should likely adhere to the OpenAI chat model. Here is such a
template for Gemma:

--chat-template "{{ bos_token }}{% set extra_system = '' %}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{% if role == 'system' %}{% set extra_system = extra_system + message['content'] %}{% else %}{% if role == 'user' and extra_system %}{% set message_system = 'System: ' + extra_system %}{% else %}{% set message_system = '' %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message_system + message['content'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

* remove convoluted solution

* Tweak for when None is provided explicitly, and must be set to [] too.

For example, the outlines library provides None explicitly.
* style

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/server.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 0523be50..97a9b40c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
-        stop_words = self.body.get("stop", [])
+        stop_words = self.body.get("stop")
+        stop_words = stop_words or []
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)
@@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler):
         if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
             raise ValueError("max_tokens must be a non-negative integer")
 
-        if not isinstance(self.temperature, float) or self.temperature < 0:
+        if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
             raise ValueError("temperature must be a non-negative float")
 
-        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+        if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
             raise ValueError("top_p must be a float between 0 and 1")
 
         if (
-            not isinstance(self.repetition_penalty, float)
+            not isinstance(self.repetition_penalty, (float, int))
             or self.repetition_penalty < 0
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
@@ -527,6 +528,18 @@ def main():
         help="Set the MLX cache limit in GB",
         required=False,
     )
+    parser.add_argument(
+        "--chat-template",
+        type=str,
+        default="",
+        help="Specify a chat template for the tokenizer",
+        required=False,
+    )
+    parser.add_argument(
+        "--use-default-chat-template",
+        action="store_true",
+        help="Use the default chat template",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -540,10 +553,17 @@ def main():
 
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.chat_template:
+        tokenizer_config["chat_template"] = args.chat_template
 
     model, tokenizer = load(
        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
    )
+
+    if args.use_default_chat_template:
+        if tokenizer.chat_template is None:
+            tokenizer.chat_template = tokenizer.default_chat_template
+
     run(args.host, args.port, model, tokenizer)
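
A quick client-side sanity check for the request-validation changes above (a sketch, not part of the patch): it sends `"stop": null` explicitly, the way outlines does, plus integer `temperature`/`top_p` values. The port and endpoint path are assumptions carried over from the examples in the commit message, and `requests` is simply my choice of HTTP client.

```python
import requests  # assumed HTTP client; any client would do

# Exercises the validation fixes above: an explicit null "stop" is now
# coerced to [], and integer temperature/top_p now pass validation.
resp = requests.post(
    "http://localhost:1143/v1/chat/completions",  # port from the example above
    json={
        "messages": [{"role": "user", "content": "hello"}],
        "stop": None,      # serialized as null; iterating None used to raise
        "temperature": 0,  # int, previously rejected as not a float
        "top_p": 1,        # int, previously rejected as not a float
        "max_tokens": 250,
    },
)
print(resp.json())
```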
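
And a launch sketch for the two new flags, assuming an installed mlx_lm package so that `python -m mlx_lm.server` resolves (the commit message above instead ran the file directly as `python -m server` with patched imports); `$GEMMA_TEMPLATE` is a hypothetical shell variable holding the Gemma template string quoted earlier.

```sh
# Fall back to the tokenizer's default template when none is configured:
python -m mlx_lm.server --model mlx-community/gemma-1.1-7b-it-4bit \
    --port 1143 --use-default-chat-template

# Or inject a custom template, e.g. the Gemma one shown above:
python -m mlx_lm.server --model mlx-community/gemma-1.1-7b-it-4bit \
    --port 1143 --chat-template "$GEMMA_TEMPLATE"
```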