Predict stop sequence matches during streaming (#541)

* Predict stop sequence matches during streaming

Check for overlap between the stop sequences and the end of the tokens array to catch partial matches that could complete as more tokens are generated. Keep generating tokens until we can confirm whether or not the stop sequence is met.
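For example, if a stop sequence is the token ids [7, 8, 9] and the generated tokens currently end in [..., 7, 8], the last two tokens may be the start of the stop sequence, so they cannot be streamed yet. A minimal sketch of the check (the helper name should_hold_back is illustrative, not part of the server code):

    def should_hold_back(tokens, stop_id_sequences):
        # Hold pending output back while the tail of `tokens` could
        # still grow into one of the stop sequences.
        return any(sequence_overlap(tokens, s) for s in stop_id_sequences)

Here sequence_overlap is the new helper introduced by this commit (see the diff below).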

* fix typo

* Change sequence_overlap logic

* range isn't inclusive, add 1 to max_overlap (see the note after this list)

* Add test_server.py

Added a test for the sequence_overlap method

* nits

* eos sequence

* finalize
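A note on the range bullet above: Python's range(start, stop) excludes stop, so checking overlaps of length 1 through max_overlap needs an upper bound of max_overlap + 1. A quick illustration:

    >>> max_overlap = 3
    >>> list(range(1, max_overlap))      # misses the full-length overlap
    [1, 2]
    >>> list(range(1, max_overlap + 1))  # covers lengths 1 through 3
    [1, 2, 3]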

---------

Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Authored by tidely on 2024-08-07 01:24:15 +03:00, committed by GitHub
parent 8fa12b0058
commit df744c98e6
2 changed files with 47 additions and 22 deletions

mlx_lm/server.py

@@ -8,7 +8,7 @@ import uuid
 import warnings
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
 
 import mlx.core as mx
@@ -54,6 +54,21 @@ def stopping_criteria(
     return StopCondition(stop_met=False, trim_length=0)
 
 
+def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
+    """
+    Checks if a suffix of s1 has overlap with a prefix of s2
+
+    Args:
+        s1 (Sequence): The first sequence
+        s2 (Sequence): The second sequence
+
+    Returns:
+        bool: If the two sequences have overlap
+    """
+    max_overlap = min(len(s1), len(s2))
+    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
+
+
 def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": (
@@ -462,12 +477,13 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_id_sequences: List[List[int]],
     ):
         """
-        Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream.
+        Generate response to prompt and forward it to the client using a Server
+        Sent Events (SSE) stream.
 
         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
-            stop_id_sequences (List[List[int]]):
-                A list of stop words passed to the stopping_criteria function
+            stop_id_sequences (List[List[int]]): A list of stop words passed to
+                the stopping_criteria function
         """
         # No additional headers are needed, call end_headers
         self.end_headers()
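For intuition about the helper added above: sequence_overlap returns True only when some suffix of s1 equals a prefix of s2. An illustrative REPL session (not part of the diff):

    >>> sequence_overlap([5, 6, 7], [7, 8])        # suffix [7] matches prefix [7]
    True
    >>> sequence_overlap([5, 6, 7], [8, 9])        # no suffix is a prefix of [8, 9]
    False
    >>> sequence_overlap([4, 1, 2, 3], [1, 2, 3])  # full stop sequence at the tail
    True

In the streaming loop below, s1 is the running token list and s2 a stop id sequence, so True means a stop sequence may still be in progress.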
@@ -476,12 +492,9 @@ class APIHandler(BaseHTTPRequestHandler):
         detokenizer.reset()
         tokens = []
 
-        max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
-        # Buffer to store the last `max_stop_id_sequence_len` tokens
-        # to check for stop conditions before writing to the stream.
-        stop_sequence_buffer = []
         stop_sequence_suffix = None
 
         logging.debug(f"Starting stream:")
 
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
@@ -496,11 +509,6 @@ class APIHandler(BaseHTTPRequestHandler):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
-            stop_sequence_buffer.append(token)
-
-            # Continue generating tokens until buffer is as large as the longest stop_id_sequence
-            if len(stop_sequence_buffer) < max_stop_id_sequence_len:
-                continue
 
             stop_condition = stopping_criteria(
                 tokens,
@@ -514,21 +522,25 @@ class APIHandler(BaseHTTPRequestHandler):
                     )
                 break
 
+            # If the end of tokens overlaps with a stop sequence, generate new
+            # tokens until we know if the stop sequence is hit or not
+            if any(
+                (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
+            ):
+                continue
+
             new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
-            stop_sequence_buffer = []
 
-        # check is there any remaining text to send
-        if stop_sequence_buffer:
-            next_chunk = (
-                detokenizer.last_segment
-                if stop_sequence_suffix is None
-                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
-            )
-            response = self.generate_response(next_chunk, "length")
+        detokenizer.finalize()
+        last_segment = detokenizer.last_segment
+        if last_segment:
+            if stop_sequence_suffix is not None:
+                last_segment = last_segment[: -len(stop_sequence_suffix)]
+            response = self.generate_response(last_segment, "length")
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
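The effect of the new loop can be seen in a stripped-down simulation (no HTTP, detokenizer, or generate_step; the stream function below is illustrative only):

    def stream(token_iter, stop_id_sequences):
        tokens, emitted = [], 0
        for token in token_iter:
            tokens.append(token)
            # A completed stop sequence ends the stream; the matched
            # tokens are trimmed rather than emitted.
            if any(len(tokens) >= len(s) and tokens[-len(s):] == s
                   for s in stop_id_sequences):
                break
            # A partial match at the tail: upcoming tokens may still
            # complete a stop sequence, so hold everything back.
            if any(sequence_overlap(tokens, s) for s in stop_id_sequences):
                continue
            yield tokens[emitted:]  # safe to flush to the client
            emitted = len(tokens)

With stop sequence [8, 9], list(stream(iter([1, 2, 7, 8, 9]), [[8, 9]])) yields [1], [2], [7]: the 8 is held back as a possible match start, and both 8 and 9 are trimmed once the stop is confirmed.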

test_server.py

@@ -1,3 +1,4 @@
+# Copyright © 2024 Apple Inc.
 import http
 import threading
 import unittest
@@ -76,6 +77,18 @@ class TestServer(unittest.TestCase):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_sequence_overlap(self):
+        from mlx_lm.server import sequence_overlap
+
+        self.assertTrue(sequence_overlap([1], [1]))
+        self.assertTrue(sequence_overlap([1, 2], [1, 2]))
+        self.assertTrue(sequence_overlap([1, 3], [3, 4]))
+        self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))
+
+        self.assertFalse(sequence_overlap([1], [2]))
+        self.assertFalse(sequence_overlap([1, 2], [3, 4]))
+        self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))
+
 
 if __name__ == "__main__":
     unittest.main()