mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-20 10:20:46 +08:00
chore(mlx-lm): support text type content
This commit is contained in:
parent
9a3ddc3e65
commit
74ae24b883
@ -113,6 +113,37 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||
prompt += role_mapping.get("assistant", "")
|
||||
return prompt.rstrip()
|
||||
|
||||
def process_message_content(messages):
|
||||
"""
|
||||
The mlx-lm server currently only supports text content. OpenAI content may
|
||||
support different types of content.
|
||||
|
||||
Args:
|
||||
message_list (list): A list of dictionaries, where each dictionary may have a
|
||||
'content' key containing a list of dictionaries with 'type' and 'text' keys.
|
||||
|
||||
Returns:
|
||||
list: A list of dictionaries similar to the input, but with the 'content'
|
||||
field being a string instead of a list of text fragments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
||||
|
||||
"""
|
||||
processed_messages = []
|
||||
for message in messages:
|
||||
message_copy = message.copy()
|
||||
if "content" in message_copy:
|
||||
flattened_text = ""
|
||||
if isinstance(message_copy["content"], list):
|
||||
for content_fragment in message_copy["content"]:
|
||||
if content_fragment["type"] == "text" and "text" in content_fragment:
|
||||
flattened_text += content_fragment["text"]
|
||||
else:
|
||||
raise ValueError("Only 'text' content type is supported.")
|
||||
message_copy["content"] = flattened_text
|
||||
processed_messages.append(message_copy)
|
||||
return processed_messages
|
||||
|
||||
@dataclass
|
||||
class PromptCache:
|
||||
@ -591,6 +622,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||
if self.tokenizer.chat_template:
|
||||
messages = body["messages"]
|
||||
if isinstance(messages,list) and messages and isinstance(messages[0], dict) and "content" in messages[0] and isinstance(messages[0]["content"],list):
|
||||
messages = process_message_content(messages)
|
||||
body["messages"] = messages
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
body["messages"],
|
||||
body.get("tools", None),
|
||||
|
@ -79,7 +79,31 @@ class TestServer(unittest.TestCase):
|
||||
response_body = response.text
|
||||
self.assertIn("id", response_body)
|
||||
self.assertIn("choices", response_body)
|
||||
|
||||
|
||||
def test_handle_chat_completions_with_content_fragments(self):
|
||||
url = f"http://localhost:{self.port}/v1/chat/completions"
|
||||
chat_post_data = {
|
||||
"model": "chat_model",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.85,
|
||||
"repetition_penalty": 1.2,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello!"}]
|
||||
}
|
||||
]
|
||||
}
|
||||
response = requests.post(url, json=chat_post_data)
|
||||
response_body = response.text
|
||||
self.assertIn("id", response_body)
|
||||
self.assertIn("choices", response_body)
|
||||
|
||||
def test_handle_models(self):
|
||||
url = f"http://localhost:{self.port}/v1/models"
|
||||
response = requests.get(url)
|
||||
|
Loading…
Reference in New Issue
Block a user