nits + format

This commit is contained in:
Awni Hannun 2025-01-27 16:43:51 -08:00
parent 07f3d7d6bb
commit d98bf6f798
2 changed files with 24 additions and 31 deletions

View File

@ -113,37 +113,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
prompt += role_mapping.get("assistant", "") prompt += role_mapping.get("assistant", "")
return prompt.rstrip() return prompt.rstrip()
def process_message_content(messages): def process_message_content(messages):
""" """
The mlx-lm server currently only supports text content. OpenAI content may Convert message content to a format suitable for `apply_chat_template`.
support different types of content.
The function operates on messages in place. It converts the 'content' field
to a string instead of a list of text fragments.
Args: Args:
message_list (list): A list of dictionaries, where each dictionary may have a message_list (list): A list of dictionaries, where each dictionary may
'content' key containing a list of dictionaries with 'type' and 'text' keys. have a 'content' key containing a list of dictionaries with 'type' and
'text' keys.
Returns:
list: A list of dictionaries similar to the input, but with the 'content'
field being a string instead of a list of text fragments.
Raises: Raises:
ValueError: If the 'content' type is not supported or if 'text' is missing. ValueError: If the 'content' type is not supported or if 'text' is missing.
""" """
processed_messages = []
for message in messages: for message in messages:
message_copy = message.copy() content = message["content"]
if "content" in message_copy and isinstance(message_copy["content"], list): if isinstance(content, list):
text_fragments = [ text_fragments = [
fragment["text"] fragment["text"] for fragment in content if fragment["type"] == "text"
for fragment in message_copy["content"]
if fragment.get("type") == "text"
] ]
if len(text_fragments) != len(message_copy["content"]): if len(text_fragments) != len(content):
raise ValueError("Only 'text' content type is supported.") raise ValueError("Only 'text' content type is supported.")
message_copy["content"] = "".join(text_fragments) message["content"] = "".join(text_fragments)
processed_messages.append(message_copy)
return processed_messages
@dataclass @dataclass
class PromptCache: class PromptCache:
@ -623,11 +619,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if self.tokenizer.chat_template: if self.tokenizer.chat_template:
messages = body["messages"] messages = body["messages"]
if isinstance(messages,list) and messages and isinstance(messages[0], dict) and "content" in messages[0] and isinstance(messages[0]["content"],list): process_message_content(messages)
messages = process_message_content(messages)
body["messages"] = messages
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
body["messages"], messages,
body.get("tools", None), body.get("tools", None),
add_generation_prompt=True, add_generation_prompt=True,
) )

View File

@ -79,11 +79,11 @@ class TestServer(unittest.TestCase):
response_body = response.text response_body = response.text
self.assertIn("id", response_body) self.assertIn("id", response_body)
self.assertIn("choices", response_body) self.assertIn("choices", response_body)
def test_handle_chat_completions_with_content_fragments(self): def test_handle_chat_completions_with_content_fragments(self):
url = f"http://localhost:{self.port}/v1/chat/completions" url = f"http://localhost:{self.port}/v1/chat/completions"
chat_post_data = { chat_post_data = {
"model": "chat_model", "model": "chat_model",
"max_tokens": 10, "max_tokens": 10,
"temperature": 0.7, "temperature": 0.7,
"top_p": 0.85, "top_p": 0.85,
@ -91,19 +91,18 @@ class TestServer(unittest.TestCase):
"messages": [ "messages": [
{ {
"role": "system", "role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}] "content": [
{"type": "text", "text": "You are a helpful assistant."}
],
}, },
{ {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
"role": "user", ],
"content": [{"type": "text", "text": "Hello!"}]
}
]
} }
response = requests.post(url, json=chat_post_data) response = requests.post(url, json=chat_post_data)
response_body = response.text response_body = response.text
self.assertIn("id", response_body) self.assertIn("id", response_body)
self.assertIn("choices", response_body) self.assertIn("choices", response_body)
def test_handle_models(self): def test_handle_models(self):
url = f"http://localhost:{self.port}/v1/models" url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url) response = requests.get(url)