mlx-examples/llms/tests/test_server.py
tidely df744c98e6
Predict stop sequence matches during streaming (#541)
* Predict stop sequence matches during streaming

While streaming, check whether the end of the generated tokens array overlaps the beginning of a stop sequence, since such an overlap could turn into a full stop-sequence match once more tokens are generated. Keep generating tokens until we can confirm that the stop sequence is not met.

* fix typo

* Change sequence_overlap logic

* range isn't inclusive, add 1 to max_overlap

* Add test_server.py

Added a test for the sequence_overlap method

* nits

* eos sequence

* finalize

---------

Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-06 15:24:15 -07:00

95 lines
2.9 KiB
Python

# Copyright © 2024 Apple Inc.
import http
import threading
import unittest
import requests
from mlx_lm.server import APIHandler
from mlx_lm.utils import load
class DummyModelProvider:
    """Stand-in model provider for the server tests.

    Loads a single model/tokenizer pair at construction time and hands that
    same pair back for every ``load`` request.
    """

    # Model identifier passed to ``mlx_lm.utils.load`` when none is given.
    HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"

    def __init__(self, model_path=None):
        """Load the model and tokenizer once, up front.

        Args:
            model_path: Identifier forwarded to ``mlx_lm.utils.load``.
                Defaults to ``HF_MODEL_PATH`` when omitted.
        """
        if model_path is None:
            model_path = self.HF_MODEL_PATH
        # ``load`` here is the module-level import, not the method below.
        self.model, self.tokenizer = load(model_path)

    def load(self, model, adapter=None):
        """Return the preloaded ``(model, tokenizer)`` pair.

        Args:
            model: Requested model name; only ``"default_model"`` and
                ``"chat_model"`` are accepted.
            adapter: Accepted for signature compatibility; not used.

        Raises:
            AssertionError: If ``model`` is not one of the accepted names.
        """
        assert model in ["default_model", "chat_model"]
        return self.model, self.tokenizer
class TestServer(unittest.TestCase):
    """End-to-end tests for the ``mlx_lm.server`` HTTP API.

    One HTTP server is started for the whole class on an OS-assigned port and
    served from a background thread; the tests talk to it with ``requests``.
    """

    # Upper bound (seconds) for one HTTP round trip, so a stalled server makes
    # the test fail instead of hanging the whole run.
    REQUEST_TIMEOUT = 120

    @classmethod
    def setUpClass(cls):
        """Create the model provider and start the API server in a thread."""
        # Import the submodule explicitly: a bare ``import http`` does not
        # load ``http.server``; the attribute would only exist if some other
        # module had already imported it.
        from http.server import HTTPServer

        cls.model_provider = DummyModelProvider()
        # Port 0 lets the OS pick a free port; the actual port is read back
        # from the server object below.
        cls.server_address = ("localhost", 0)
        cls.httpd = HTTPServer(
            cls.server_address,
            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
        )
        cls.port = cls.httpd.server_port
        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
        cls.server_thread.daemon = True
        cls.server_thread.start()

    @classmethod
    def tearDownClass(cls):
        """Stop serving, close the listening socket and join the thread."""
        cls.httpd.shutdown()
        cls.httpd.server_close()
        cls.server_thread.join()

    def _post_and_check(self, path, payload):
        """POST ``payload`` as JSON to ``path`` and validate the response.

        Asserts an HTTP 200 status and that the decoded JSON body carries the
        ``id`` and ``choices`` keys.

        Returns:
            The decoded JSON response body.
        """
        url = f"http://localhost:{self.port}{path}"
        response = requests.post(url, json=payload, timeout=self.REQUEST_TIMEOUT)
        self.assertEqual(response.status_code, 200)
        # Check keys on the parsed body rather than substrings of the raw
        # text, which could match inside unrelated content.
        body = response.json()
        self.assertIn("id", body)
        self.assertIn("choices", body)
        return body

    def test_handle_completions(self):
        """The text completion endpoint returns an id and choices."""
        post_data = {
            "model": "default_model",
            "prompt": "Once upon a time",
            "max_tokens": 10,
            "temperature": 0.5,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "repetition_context_size": 20,
            "stop": "stop sequence",
        }
        self._post_and_check("/v1/completions", post_data)

    def test_handle_chat_completions(self):
        """The chat completion endpoint returns an id and choices."""
        chat_post_data = {
            "model": "chat_model",
            "max_tokens": 10,
            "temperature": 0.7,
            "top_p": 0.85,
            "repetition_penalty": 1.2,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ],
        }
        self._post_and_check("/v1/chat/completions", chat_post_data)

    def test_sequence_overlap(self):
        """Check ``sequence_overlap`` on overlapping and disjoint inputs."""
        from mlx_lm.server import sequence_overlap

        # Cases where the end of the first sequence matches the start of the
        # second one.
        self.assertTrue(sequence_overlap([1], [1]))
        self.assertTrue(sequence_overlap([1, 2], [1, 2]))
        self.assertTrue(sequence_overlap([1, 3], [3, 4]))
        self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))

        # Cases with no such match, including one where the shared elements
        # sit at the end of the second sequence rather than its start.
        self.assertFalse(sequence_overlap([1], [2]))
        self.assertFalse(sequence_overlap([1, 2], [3, 4]))
        self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()