Predict stop sequence matches during streaming (#541)

* Predict stop sequence matches during streaming

Check whether the end of the tokens array overlaps the start of any stop sequence, since such an overlap could become a full match after more tokens get generated. While an overlap exists, keep generating tokens until we can confirm that the stop sequence is not met.

* fix typo

* Change sequence_overlap logic

* range isn't inclusive, add 1 to max_overlap

* Add test_server.py

Added a test for the sequence_overlap function

* nits

* eos sequence

* finalize
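
For reference, the overlap check described in this change can be sketched as below. This is a minimal illustration that satisfies the assertions in test_sequence_overlap; the type hints and docstring are assumptions, and the code in mlx_lm/server.py may differ in detail.

```python
from typing import Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
    """Return True if some non-empty suffix of s1 equals a prefix of s2."""
    max_overlap = min(len(s1), len(s2))
    # range() excludes its upper bound, so add 1 to also test the
    # longest possible overlap (length == max_overlap).
    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
```

Called with the generated tokens as s1 and a stop sequence as s2, a True result means the stop sequence may still be completed by tokens that have not been generated yet.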

---------

Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni@apple.com>
tidely
2024-08-07 01:24:15 +03:00
committed by GitHub
parent 8fa12b0058
commit df744c98e6
2 changed files with 47 additions and 22 deletions

@@ -1,3 +1,4 @@
# Copyright © 2024 Apple Inc.
import http
import threading
import unittest
@@ -76,6 +77,18 @@ class TestServer(unittest.TestCase):
        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)

    def test_sequence_overlap(self):
        from mlx_lm.server import sequence_overlap

        self.assertTrue(sequence_overlap([1], [1]))
        self.assertTrue(sequence_overlap([1, 2], [1, 2]))
        self.assertTrue(sequence_overlap([1, 3], [3, 4]))
        self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))

        self.assertFalse(sequence_overlap([1], [2]))
        self.assertFalse(sequence_overlap([1, 2], [3, 4]))
        self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))


if __name__ == "__main__":
    unittest.main()
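
The commit message describes holding back output while a stop sequence might still be completed. A rough sketch of that idea, working on token ids, is given below. The name stream_until_stop, its signature, and the choice to drop the matched stop tokens from the output are all hypothetical and are not taken from mlx_lm/server.py; sequence_overlap is repeated here only to keep the example self-contained.

```python
from typing import Iterable, Iterator, List, Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
    """Return True if some non-empty suffix of s1 equals a prefix of s2."""
    max_overlap = min(len(s1), len(s2))
    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))


def stream_until_stop(
    tokens: Iterable[int], stop_sequences: List[List[int]]
) -> Iterator[List[int]]:
    """Yield chunks of tokens that are safe to send to the client.

    Hypothetical helper: tokens are buffered while their tail could still
    grow into a stop sequence, and generation ends on a full match.
    """
    stops = [s for s in stop_sequences if s]  # ignore empty stop sequences
    pending: List[int] = []
    for token in tokens:
        pending.append(token)

        # A stop sequence is fully present at the end of the buffer: stop.
        matched = next(
            (s for s in stops if len(s) <= len(pending) and pending[-len(s):] == s),
            None,
        )
        if matched is not None:
            head = pending[: -len(matched)]
            if head:
                yield head  # emit what preceded the stop sequence
            return

        # The tail could still become a stop sequence: keep buffering.
        if any(sequence_overlap(pending, s) for s in stops):
            continue

        # No stop sequence can start inside the buffer: flush it.
        yield pending
        pending = []

    # The token source ended while tokens were still held back.
    if pending:
        yield pending


if __name__ == "__main__":
    # Tokens 5, 6 are held back, released once 8 rules out [5, 6, 7],
    # and the later full match ends the stream before 9 is reached.
    print(list(stream_until_stop([1, 5, 6, 8, 5, 6, 7, 9], [[5, 6, 7]])))
    # [[1], [5, 6, 8]]
```

Buffering only while an overlap exists keeps streaming latency low in the common case, since most tokens are flushed as soon as they arrive.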