chore(mlx-lm): support text type content in messages (#1225)
Author: Anchen

* chore(mlx-lm): support text type content
* chore: optimize the message content processing
* nits + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>

mlx_lm/server.py
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        messages (list): A list of dictionaries, where each dictionary may
+          have a 'content' key containing a list of dictionaries with 'type' and
+          'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported.
+
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
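
To illustrate the in-place normalization the helper performs, here is a minimal sketch (the sample messages are hypothetical; it assumes process_message_content from the hunk above is in scope). The second server.py hunk, below, wires this into the request handler.

    # Minimal sketch of process_message_content's behavior (hypothetical data).
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": "Hello, "},
            {"type": "text", "text": "world!"},
        ]},
        {"role": "assistant", "content": "Hi!"},  # plain strings pass through untouched
    ]
    process_message_content(messages)
    print(messages[0]["content"])  # -> "Hello, world!"
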
@@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
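
One consequence of the length check in the helper: a non-text fragment (for example an image_url part from the multimodal OpenAI schema) makes the request fail fast instead of being silently dropped. A hedged sketch, again assuming the function is in scope and using a made-up payload:

    # Non-text fragments are rejected, not skipped (hypothetical payload).
    bad = [{"role": "user", "content": [
        {"type": "text", "text": "Describe this:"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]}]
    try:
        process_message_content(bad)
    except ValueError as err:
        print(err)  # -> Only 'text' content type is supported.
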
tests/test_server.py
@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)
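
Plain string content keeps working alongside the new fragment form, since the helper only rewrites list-valued 'content'. A hypothetical request mixing both styles against a locally running server (the port is an assumption):

    import requests

    # Hypothetical mixed-style request: string content for the system turn,
    # fragment-list content for the user turn. Port 8080 is an assumption.
    url = "http://localhost:8080/v1/chat/completions"
    payload = {
        "model": "chat_model",
        "max_tokens": 10,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
        ],
    }
    response = requests.post(url, json=payload)
    print(response.json()["choices"][0]["message"]["content"])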