# Copyright © 2024 Apple Inc.

import http
import http.server  # explicit: `import http` alone does not load the submodule
import json
import threading
import unittest

import requests
from mlx_lm.server import APIHandler
from mlx_lm.utils import load


class DummyModelProvider:
    """Model provider stub that always serves one preloaded model.

    The model and tokenizer are loaded once at construction time and handed
    back unchanged on every ``load`` call.
    """

    def __init__(self):
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        self.model, self.tokenizer = load(HF_MODEL_PATH)

    def load(self, model, adapter=None):
        """Return the preloaded ``(model, tokenizer)`` pair.

        Args:
            model: Requested model name; must be ``"default_model"`` or
                ``"chat_model"``.
            adapter: Accepted but not used by this stub.

        Raises:
            AssertionError: If ``model`` is not one of the two known names.
        """
        assert model in ["default_model", "chat_model"]
        return self.model, self.tokenizer


class TestServer(unittest.TestCase):
    """Tests that issue HTTP requests to an ``APIHandler`` served in a thread."""

    @classmethod
    def setUpClass(cls):
        """Start one HTTP server, shared by all tests, on a background thread."""
        cls.model_provider = DummyModelProvider()
        # Port 0 is requested here; the port actually bound is read back from
        # ``server_port`` below and used to build the request URLs.
        cls.server_address = ("localhost", 0)
        cls.httpd = http.server.HTTPServer(
            cls.server_address,
            # HTTPServer calls the handler factory once per request; the
            # lambda prepends the shared model provider to those arguments.
            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
        )
        cls.port = cls.httpd.server_port
        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
        cls.server_thread.daemon = True
        cls.server_thread.start()

    @classmethod
    def tearDownClass(cls):
        """Stop the serve loop, close the listening socket and join the thread."""
        cls.httpd.shutdown()
        cls.httpd.server_close()
        cls.server_thread.join()

    def test_handle_completions(self):
        """POST to ``/v1/completions`` and check the body mentions id/choices."""
        url = f"http://localhost:{self.port}/v1/completions"

        post_data = {
            "model": "default_model",
            "prompt": "Once upon a time",
            "max_tokens": 10,
            "temperature": 0.5,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "repetition_context_size": 20,
            "stop": "stop sequence",
        }

        response = requests.post(url, json=post_data)

        # Substring checks against the raw response text, not parsed JSON.
        response_body = response.text

        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)

    def test_handle_chat_completions(self):
        """POST to ``/v1/chat/completions`` and check the body mentions id/choices."""
        url = f"http://localhost:{self.port}/v1/chat/completions"

        chat_post_data = {
            "model": "chat_model",
            "max_tokens": 10,
            "temperature": 0.7,
            "top_p": 0.85,
            "repetition_penalty": 1.2,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ],
        }

        response = requests.post(url, json=chat_post_data)

        # Substring checks against the raw response text, not parsed JSON.
        response_body = response.text

        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)

    def test_handle_models(self):
        """GET ``/v1/models`` and validate the shape of the returned listing."""
        url = f"http://localhost:{self.port}/v1/models"

        response = requests.get(url)
        self.assertEqual(response.status_code, 200)

        response_body = json.loads(response.text)
        self.assertEqual(response_body["object"], "list")
        self.assertIsInstance(response_body["data"], list)
        self.assertGreater(len(response_body["data"]), 0)

        # Only the first entry is inspected.
        model = response_body["data"][0]
        self.assertIn("id", model)
        self.assertEqual(model["object"], "model")
        self.assertIn("created", model)

    def test_sequence_overlap(self):
        """Check ``sequence_overlap`` on overlapping and disjoint token lists.

        NOTE(review): judging only by the cases below, the helper appears to
        return True when some suffix of the first sequence equals a prefix of
        the second -- confirm against its definition in ``mlx_lm.server``.
        """
        from mlx_lm.server import sequence_overlap

        self.assertTrue(sequence_overlap([1], [1]))
        self.assertTrue(sequence_overlap([1, 2], [1, 2]))
        self.assertTrue(sequence_overlap([1, 3], [3, 4]))
        self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))

        self.assertFalse(sequence_overlap([1], [2]))
        self.assertFalse(sequence_overlap([1, 2], [3, 4]))
        self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))


if __name__ == "__main__":
    unittest.main()