From 74ae24b883c88d574f2ca8256a6d10ec3adac083 Mon Sep 17 00:00:00 2001
From: anchen
Date: Mon, 27 Jan 2025 00:00:10 +1100
Subject: [PATCH] chore(mlx-lm): support text type content

---
 llms/mlx_lm/server.py     | 35 +++++++++++++++++++++++++++++++++++
 llms/tests/test_server.py | 26 +++++++++++++++++++++++++-
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 4523e3ae..4124dfa0 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -113,6 +113,37 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     prompt += role_mapping.get("assistant", "")
     return prompt.rstrip()
 
+def process_message_content(messages):
+    """
+    Flatten OpenAI-style list content into plain text. The mlx-lm server only
+    supports text content; the OpenAI API also allows a list of typed fragments.
+
+    Args:
+        messages (list): A list of dictionaries, where each dictionary may have a
+            'content' key containing a list of dictionaries with 'type' and 'text' keys.
+
+    Returns:
+        list: A list of dictionaries similar to the input, but with the 'content'
+        field being a string instead of a list of text fragments.
+
+    Raises:
+        ValueError: If a fragment's type is not 'text' or its 'text' key is missing.
+
+    """
+    processed_messages = []
+    for message in messages:
+        message_copy = message.copy()
+        # Only flatten list content; plain string content is left untouched.
+        if isinstance(message_copy.get("content"), list):
+            flattened_text = ""
+            for content_fragment in message_copy["content"]:
+                if content_fragment.get("type") == "text" and "text" in content_fragment:
+                    flattened_text += content_fragment["text"]
+                else:
+                    raise ValueError("Only 'text' content type is supported.")
+            message_copy["content"] = flattened_text
+        processed_messages.append(message_copy)
+    return processed_messages
 
 @dataclass
 class PromptCache:
@@ -591,6 +622,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            # Flatten any OpenAI-style list content before applying the template.
+            if isinstance(messages, list):
+                body["messages"] = process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 body.get("tools", None),
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index ad17554d..16133c35 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -79,7 +79,31 @@ class TestServer(unittest.TestCase):
         response_body = response.text
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
-
+
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": "You are a helpful assistant."}],
+                },
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "Hello!"}],
+                },
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)
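
-- 
Reviewer note, not part of the patch: a quick sketch of the helper's behavior
once this is applied. The import path assumes the installed mlx_lm package.
Plain string content passes through untouched, list content is flattened, and
non-text fragments raise a ValueError:

    from mlx_lm.server import process_message_content

    messages = [
        # OpenAI-style content: a list of typed fragments.
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello, "},
                {"type": "text", "text": "world!"},
            ],
        },
        # Plain string content is left as-is.
        {"role": "assistant", "content": "Hi!"},
    ]

    print(process_message_content(messages))
    # [{'role': 'user', 'content': 'Hello, world!'},
    #  {'role': 'assistant', 'content': 'Hi!'}]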
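
End to end, the new path can be exercised with a request like the one in the
added test. A minimal sketch, assuming a locally running `mlx_lm.server`; the
host, port, and model name are illustrative:

    import requests

    response = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "model": "chat_model",
            "max_tokens": 10,
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": "Hello!"}],
                },
            ],
        },
    )
    print(response.json())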