diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 334a17d1..de02704d 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -113,37 +113,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
         prompt += role_mapping.get("assistant", "")
     return prompt.rstrip()
 
+
 def process_message_content(messages):
     """
-    The mlx-lm server currently only supports text content. OpenAI content may
-    support different types of content.
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
 
     Args:
-        message_list (list): A list of dictionaries, where each dictionary may have a
-            'content' key containing a list of dictionaries with 'type' and 'text' keys.
-
-    Returns:
-        list: A list of dictionaries similar to the input, but with the 'content'
-            field being a string instead of a list of text fragments.
+        message_list (list): A list of dictionaries, where each dictionary may
+            have a 'content' key containing a list of dictionaries with 'type' and
+            'text' keys.
 
     Raises:
         ValueError: If the 'content' type is not supported or if 'text' is missing.
 
     """
-    processed_messages = []
     for message in messages:
-        message_copy = message.copy()
-        if "content" in message_copy and isinstance(message_copy["content"], list):
+        content = message["content"]
+        if isinstance(content, list):
             text_fragments = [
-                fragment["text"]
-                for fragment in message_copy["content"]
-                if fragment.get("type") == "text"
+                fragment["text"] for fragment in content if fragment["type"] == "text"
             ]
-            if len(text_fragments) != len(message_copy["content"]):
+            if len(text_fragments) != len(content):
                 raise ValueError("Only 'text' content type is supported.")
-            message_copy["content"] = "".join(text_fragments)
-            processed_messages.append(message_copy)
-    return processed_messages
+            message["content"] = "".join(text_fragments)
+
 
 @dataclass
 class PromptCache:
@@ -623,11 +619,9 @@ class APIHandler(BaseHTTPRequestHandler):
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
             messages = body["messages"]
-            if isinstance(messages,list) and messages and isinstance(messages[0], dict) and "content" in messages[0] and isinstance(messages[0]["content"],list):
-                messages = process_message_content(messages)
-                body["messages"] = messages
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index 16133c35..ecf95f78 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -79,11 +79,11 @@ class TestServer(unittest.TestCase):
         response_body = response.text
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
-
+
     def test_handle_chat_completions_with_content_fragments(self):
         url = f"http://localhost:{self.port}/v1/chat/completions"
         chat_post_data = {
-            "model": "chat_model",
+            "model": "chat_model",
             "max_tokens": 10,
             "temperature": 0.7,
             "top_p": 0.85,
@@ -91,19 +91,18 @@ class TestServer(unittest.TestCase):
             "messages": [
                 {
                     "role": "system",
-                    "content": [{"type": "text", "text": "You are a helpful assistant."}]
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
                 },
-                {
-                    "role": "user",
-                    "content": [{"type": "text", "text": "Hello!"}]
-                }
-            ]
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
         }
         response = requests.post(url, json=chat_post_data)
         response_body = response.text
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
-
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)