From 74ae24b883c88d574f2ca8256a6d10ec3adac083 Mon Sep 17 00:00:00 2001
From: anchen
Date: Mon, 27 Jan 2025 00:00:10 +1100
Subject: [PATCH] chore(mlx-lm): support text type content
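
The OpenAI chat API allows a message's "content" to be either a plain
string or a list of typed fragments such as
[{"type": "text", "text": "..."}]. The server previously passed this
list straight to apply_chat_template, whose templates assume string
content. This patch flattens text fragments into a single string and
raises a ValueError for any other fragment type.

A rough sketch of the intended behavior (illustrative values only):

    messages = [
        {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
        {"role": "assistant", "content": "Hi!"},  # plain strings pass through
    ]
    processed = process_message_content(messages)
    assert processed[0]["content"] == "Hello"
    assert processed[1]["content"] == "Hi!"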
---
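Note for reviewers: a quick manual check against a locally running
server (assuming the default port 8080; field values are illustrative):

    import requests

    response = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": [{"type": "text", "text": "Hi"}]}
            ],
        },
    )
    print(response.json())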
llms/mlx_lm/server.py | 35 +++++++++++++++++++++++++++++++++++
llms/tests/test_server.py | 26 +++++++++++++++++++++++++-
2 files changed, 60 insertions(+), 1 deletion(-)
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 4523e3ae..4124dfa0 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -113,6 +113,37 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
         prompt += role_mapping.get("assistant", "")
     return prompt.rstrip()
+
+
+def process_message_content(messages):
+    """
+    Flatten OpenAI-style content fragments into plain text. The mlx-lm
+    server only supports text content, while the OpenAI API also allows
+    a message's 'content' to be a list of typed fragments.
+
+    Args:
+        messages (list): Message dictionaries whose 'content' may be
+            a string or a list of dicts with 'type' and 'text' keys.
+
+    Returns:
+        list: The messages with every 'content' flattened to a string.
+
+    Raises:
+        ValueError: If a fragment's type is not 'text' or 'text' is missing.
+    """
+    processed_messages = []
+    for message in messages:
+        message_copy = message.copy()
+        if isinstance(message_copy.get("content"), list):
+            flattened_text = ""
+            for fragment in message_copy["content"]:
+                if fragment.get("type") == "text" and "text" in fragment:
+                    flattened_text += fragment["text"]
+                else:
+                    raise ValueError("Only 'text' content fragments are supported.")
+            message_copy["content"] = flattened_text
+        processed_messages.append(message_copy)
+    return processed_messages
 
 
 @dataclass
 class PromptCache:
@@ -591,6 +622,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            if isinstance(messages, list):
+                messages = process_message_content(messages)
+            body["messages"] = messages
             prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 body.get("tools", None),
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index ad17554d..16133c35 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -79,7 +79,31 @@ class TestServer(unittest.TestCase):
         response_body = response.text
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
-
+
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": "You are a helpful assistant."}],
+                },
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "Hello!"}],
+                },
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)