From df744c98e67f94fba8a893cb5aba4fdb735f5f4a Mon Sep 17 00:00:00 2001 From: tidely <43219534+tidely@users.noreply.github.com> Date: Wed, 7 Aug 2024 01:24:15 +0300 Subject: [PATCH] Predict stop sequence matches during streaming (#541) * Predict stop sequence matches during streaming Check for overlap of stop sequences and the tokens array for potential sequence matches after more tokens get generated. Generate tokens until we can confirm that the stop sequence is not met. * fix typo * Change sequence_overlap logic * range isn't inclusive, add 1 to max_overlap * Add test_server.py Added a test for the sequence_overlap method * nits * eos sequence * finalize --------- Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com> Co-authored-by: Awni Hannun --- llms/mlx_lm/server.py | 56 ++++++++++++++++++++++++--------------- llms/tests/test_server.py | 13 +++++++++ 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 7456399c..79ac1836 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -8,7 +8,7 @@ import uuid import warnings from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Union +from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union import mlx.core as mx @@ -54,6 +54,21 @@ def stopping_criteria( return StopCondition(stop_met=False, trim_length=0) +def sequence_overlap(s1: Sequence, s2: Sequence) -> bool: + """ + Checks if a suffix of s1 has overlap with a prefix of s2 + + Args: + s1 (Sequence): The first sequence + s2 (Sequence): The second sequence + + Returns: + bool: If the two sequences have overlap + """ + max_overlap = min(len(s1), len(s2)) + return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1)) + + def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): default_role_mapping = { "system_prompt": ( @@ -462,12 +477,13 @@ class APIHandler(BaseHTTPRequestHandler): stop_id_sequences: List[List[int]], ): """ - Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream. + Generate response to prompt and foward it to the client using a Server + Sent Events (SSE) stream. Args: prompt (mx.array): The prompt, in token form inside of a mlx array - stop_id_sequences (List[List[int]]): - A list of stop words passed to the stopping_criteria function + stop_id_sequences (List[List[int]]): A list of stop words passed to + the stopping_criteria function """ # No additional headers are needed, call end_headers self.end_headers() @@ -476,12 +492,9 @@ class APIHandler(BaseHTTPRequestHandler): detokenizer.reset() tokens = [] - max_stop_id_sequence_len = len(max(stop_id_sequences, default=[])) - # Buffer to store the last `max_stop_id_sequence_len` tokens - # to check for stop conditions before writing to the stream. - stop_sequence_buffer = [] stop_sequence_suffix = None logging.debug(f"Starting stream:") + for (token, _), _ in zip( generate_step( prompt=prompt, @@ -496,11 +509,6 @@ class APIHandler(BaseHTTPRequestHandler): detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) - stop_sequence_buffer.append(token) - - # Continue generating tokens until buffer is as large as the longest stop_id_sequence - if len(stop_sequence_buffer) < max_stop_id_sequence_len: - continue stop_condition = stopping_criteria( tokens, @@ -514,21 +522,25 @@ class APIHandler(BaseHTTPRequestHandler): ) break + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences) + ): + continue + new_text = detokenizer.last_segment response = self.generate_response(new_text, None) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() - stop_sequence_buffer = [] # check is there any remaining text to send - if stop_sequence_buffer: - next_chunk = ( - detokenizer.last_segment - if stop_sequence_suffix is None - else detokenizer.last_segment[: -len(stop_sequence_suffix)] - ) - response = self.generate_response(next_chunk, "length") - + detokenizer.finalize() + last_segment = detokenizer.last_segment + if last_segment: + if stop_sequence_suffix is not None: + last_segment = last_segment[: -len(stop_sequence_suffix)] + response = self.generate_response(last_segment, "length") self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index b8047eaa..baea664a 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -1,3 +1,4 @@ +# Copyright © 2024 Apple Inc. import http import threading import unittest @@ -76,6 +77,18 @@ class TestServer(unittest.TestCase): self.assertIn("id", response_body) self.assertIn("choices", response_body) + def test_sequence_overlap(self): + from mlx_lm.server import sequence_overlap + + self.assertTrue(sequence_overlap([1], [1])) + self.assertTrue(sequence_overlap([1, 2], [1, 2])) + self.assertTrue(sequence_overlap([1, 3], [3, 4])) + self.assertTrue(sequence_overlap([1, 2, 3], [2, 3])) + + self.assertFalse(sequence_overlap([1], [2])) + self.assertFalse(sequence_overlap([1, 2], [3, 4])) + self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3])) + if __name__ == "__main__": unittest.main()