diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 4124dfa0..334a17d1 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -133,15 +133,15 @@ def process_message_content(messages): processed_messages = [] for message in messages: message_copy = message.copy() - if "content" in message_copy: - flattened_text = "" - if isinstance(message_copy["content"], list): - for content_fragment in message_copy["content"]: - if content_fragment["type"] == "text" and "text" in content_fragment: - flattened_text += content_fragment["text"] - else: - raise ValueError("Only 'text' content type is supported.") - message_copy["content"] = flattened_text + if "content" in message_copy and isinstance(message_copy["content"], list): + text_fragments = [ + fragment["text"] + for fragment in message_copy["content"] + if fragment.get("type") == "text" + ] + if len(text_fragments) != len(message_copy["content"]): + raise ValueError("Only 'text' content type is supported.") + message_copy["content"] = "".join(text_fragments) processed_messages.append(message_copy) return processed_messages