Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-25 01:41:19 +08:00
Predict stop sequence matches during streaming (#541)
* Predict stop sequence matches during streaming

  Check for overlap of stop sequences and the tokens array for potential
  sequence matches after more tokens get generated. Generate tokens until
  we can confirm that the stop sequence is not met.

* fix typo

* Change sequence_overlap logic

* range isn't inclusive, add 1 to max_overlap

* Add test_server.py

  Added a test for the sequence_overlap method

* nits

* eos sequence

* finalize

---------

Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 8fa12b0058
commit df744c98e6
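The idea behind the change: a stop sequence can begin at the tail of the generated tokens without being complete yet, so the server must hold those tokens back from the SSE stream until further generation resolves the ambiguity. Below is a minimal, self-contained sketch of that hold-back pattern; the stream_with_stop helper, the token ids, and the driver at the bottom are illustrations for this note, not code from the commit (which works with detokenizer segments and stopping_criteria instead):

    from typing import Iterable, Iterator, List, Sequence

    def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
        # True when some suffix of s1 equals a prefix of s2.
        max_overlap = min(len(s1), len(s2))
        return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))

    def stream_with_stop(
        token_stream: Iterable[int], stop_id_sequences: List[List[int]]
    ) -> Iterator[List[int]]:
        tokens: List[int] = []
        emitted = 0  # index of the first token not yet sent to the client
        for token in token_stream:
            tokens.append(token)
            # Full match at the tail: stop without emitting the stop sequence.
            if any(tokens[-len(s):] == s for s in stop_id_sequences):
                return
            # Partial match at the tail: generate more before emitting.
            if any(sequence_overlap(tokens, s) for s in stop_id_sequences):
                continue
            yield tokens[emitted:]  # everything held back is now safe to send
            emitted = len(tokens)

    # With stop sequence [7, 8, 9]: tokens 7 and 8 are held until the 2 rules
    # out a match, and the final 7, 8, 9 is suppressed entirely.
    print(list(stream_with_stop(iter([1, 7, 8, 2, 7, 8, 9]), [[7, 8, 9]])))
    # [[1], [7, 8, 2]]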
mlx_lm/server.py

@@ -8,7 +8,7 @@ import uuid
 import warnings
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union

 import mlx.core as mx

@@ -54,6 +54,21 @@ def stopping_criteria(
     return StopCondition(stop_met=False, trim_length=0)


+def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
+    """
+    Checks if a suffix of s1 has overlap with a prefix of s2
+
+    Args:
+        s1 (Sequence): The first sequence
+        s2 (Sequence): The second sequence
+
+    Returns:
+        bool: If the two sequences have overlap
+    """
+    max_overlap = min(len(s1), len(s2))
+    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
+
+
 def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
         "system_prompt": (
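For intuition about the helper added here: it reports only that a stop sequence could still complete at the current tail; detecting a completed match remains the job of stopping_criteria. A few illustrative calls with made-up token ids:

    sequence_overlap([1, 2, 7], [7, 8, 9])  # True: suffix [7] starts the stop sequence
    sequence_overlap([1, 7, 8], [7, 8, 9])  # True: suffix [7, 8] could still complete it
    sequence_overlap([7, 8, 9], [7, 8, 9])  # True: a full match is also an overlap
    sequence_overlap([1, 2, 3], [7, 8, 9])  # False: no suffix of s1 begins the stop sequence

The check costs O(k^2) element comparisons for a stop sequence of length k, which is negligible for the short stop strings typical of chat templates.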
@@ -462,12 +477,13 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_id_sequences: List[List[int]],
     ):
         """
-        Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream.
+        Generate response to prompt and foward it to the client using a Server
+        Sent Events (SSE) stream.

         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
-            stop_id_sequences (List[List[int]]):
-                A list of stop words passed to the stopping_criteria function
+            stop_id_sequences (List[List[int]]): A list of stop words passed to
+                the stopping_criteria function
         """
         # No additional headers are needed, call end_headers
         self.end_headers()
@@ -476,12 +492,9 @@ class APIHandler(BaseHTTPRequestHandler):
         detokenizer.reset()
         tokens = []
-        max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
-        # Buffer to store the last `max_stop_id_sequence_len` tokens
-        # to check for stop conditions before writing to the stream.
-        stop_sequence_buffer = []
+
         stop_sequence_suffix = None
         logging.debug(f"Starting stream:")

         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
@@ -496,11 +509,6 @@ class APIHandler(BaseHTTPRequestHandler):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
-            stop_sequence_buffer.append(token)
-
-            # Continue generating tokens until buffer is as large as the longest stop_id_sequence
-            if len(stop_sequence_buffer) < max_stop_id_sequence_len:
-                continue

             stop_condition = stopping_criteria(
                 tokens,
@@ -514,21 +522,25 @@ class APIHandler(BaseHTTPRequestHandler):
                 )
                 break

+            # If the end of tokens overlaps with a stop sequence, generate new
+            # tokens until we know if the stop sequence is hit or not
+            if any(
+                (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
+            ):
+                continue
+
             new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
-            stop_sequence_buffer = []

         # check is there any remaining text to send
-        if stop_sequence_buffer:
-            next_chunk = (
-                detokenizer.last_segment
-                if stop_sequence_suffix is None
-                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
-            )
-            response = self.generate_response(next_chunk, "length")
-
+        detokenizer.finalize()
+        last_segment = detokenizer.last_segment
+        if last_segment:
+            if stop_sequence_suffix is not None:
+                last_segment = last_segment[: -len(stop_sequence_suffix)]
+            response = self.generate_response(last_segment, "length")
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()

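One detail in the final flush above: when generation breaks on a stop match, part of the stop sequence may already have been decoded into the pending text, so the code slices stop_sequence_suffix off last_segment before the last SSE write. A small string-level illustration of that slice, with hypothetical values:

    last_segment = "Hello world<|im_"  # pending text when the stop was hit
    stop_sequence_suffix = "<|im_"     # decoded from the trimmed stop tokens
    if stop_sequence_suffix is not None:
        last_segment = last_segment[: -len(stop_sequence_suffix)]
    print(last_segment)  # Hello world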
test_server.py

@@ -1,3 +1,4 @@
+# Copyright © 2024 Apple Inc.
 import http
 import threading
 import unittest
@@ -76,6 +77,18 @@ class TestServer(unittest.TestCase):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)

+    def test_sequence_overlap(self):
+        from mlx_lm.server import sequence_overlap
+
+        self.assertTrue(sequence_overlap([1], [1]))
+        self.assertTrue(sequence_overlap([1, 2], [1, 2]))
+        self.assertTrue(sequence_overlap([1, 3], [3, 4]))
+        self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))
+
+        self.assertFalse(sequence_overlap([1], [2]))
+        self.assertFalse(sequence_overlap([1, 2], [3, 4]))
+        self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))
+

 if __name__ == "__main__":
     unittest.main()