diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 7456399c..79ac1836 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -8,7 +8,7 @@ import uuid import warnings from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Union +from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union import mlx.core as mx @@ -54,6 +54,21 @@ def stopping_criteria( return StopCondition(stop_met=False, trim_length=0) +def sequence_overlap(s1: Sequence, s2: Sequence) -> bool: + """ + Checks if a suffix of s1 has overlap with a prefix of s2 + + Args: + s1 (Sequence): The first sequence + s2 (Sequence): The second sequence + + Returns: + bool: If the two sequences have overlap + """ + max_overlap = min(len(s1), len(s2)) + return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1)) + + def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): default_role_mapping = { "system_prompt": ( @@ -462,12 +477,13 @@ class APIHandler(BaseHTTPRequestHandler): stop_id_sequences: List[List[int]], ): """ - Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream. + Generate response to prompt and foward it to the client using a Server + Sent Events (SSE) stream. Args: prompt (mx.array): The prompt, in token form inside of a mlx array - stop_id_sequences (List[List[int]]): - A list of stop words passed to the stopping_criteria function + stop_id_sequences (List[List[int]]): A list of stop words passed to + the stopping_criteria function """ # No additional headers are needed, call end_headers self.end_headers() @@ -476,12 +492,9 @@ class APIHandler(BaseHTTPRequestHandler): detokenizer.reset() tokens = [] - max_stop_id_sequence_len = len(max(stop_id_sequences, default=[])) - # Buffer to store the last `max_stop_id_sequence_len` tokens - # to check for stop conditions before writing to the stream. - stop_sequence_buffer = [] stop_sequence_suffix = None logging.debug(f"Starting stream:") + for (token, _), _ in zip( generate_step( prompt=prompt, @@ -496,11 +509,6 @@ class APIHandler(BaseHTTPRequestHandler): detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) - stop_sequence_buffer.append(token) - - # Continue generating tokens until buffer is as large as the longest stop_id_sequence - if len(stop_sequence_buffer) < max_stop_id_sequence_len: - continue stop_condition = stopping_criteria( tokens, @@ -514,21 +522,25 @@ class APIHandler(BaseHTTPRequestHandler): ) break + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences) + ): + continue + new_text = detokenizer.last_segment response = self.generate_response(new_text, None) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() - stop_sequence_buffer = [] # check is there any remaining text to send - if stop_sequence_buffer: - next_chunk = ( - detokenizer.last_segment - if stop_sequence_suffix is None - else detokenizer.last_segment[: -len(stop_sequence_suffix)] - ) - response = self.generate_response(next_chunk, "length") - + detokenizer.finalize() + last_segment = detokenizer.last_segment + if last_segment: + if stop_sequence_suffix is not None: + last_segment = last_segment[: -len(stop_sequence_suffix)] + response = self.generate_response(last_segment, "length") self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index b8047eaa..baea664a 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -1,3 +1,4 @@ +# Copyright © 2024 Apple Inc. import http import threading import unittest @@ -76,6 +77,18 @@ class TestServer(unittest.TestCase): self.assertIn("id", response_body) self.assertIn("choices", response_body) + def test_sequence_overlap(self): + from mlx_lm.server import sequence_overlap + + self.assertTrue(sequence_overlap([1], [1])) + self.assertTrue(sequence_overlap([1, 2], [1, 2])) + self.assertTrue(sequence_overlap([1, 3], [3, 4])) + self.assertTrue(sequence_overlap([1, 2, 3], [2, 3])) + + self.assertFalse(sequence_overlap([1], [2])) + self.assertFalse(sequence_overlap([1, 2], [3, 4])) + self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3])) + if __name__ == "__main__": unittest.main()