Server loads the model on demand from the request (#851)

Angelos Katharopoulos
2024-06-27 11:37:57 -07:00
committed by GitHub
parent 538339b599
commit f212b770d8
2 changed files with 82 additions and 28 deletions


@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
 from mlx_lm.utils import load
 
 
+class DummyModelProvider:
+    def __init__(self):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        self.model, self.tokenizer = load(HF_MODEL_PATH)
+
+    def load(self, model):
+        assert model in ["default_model", "chat_model"]
+        return self.model, self.tokenizer
+
+
 class TestServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.model_provider = DummyModelProvider()
         cls.server_address = ("localhost", 0)
         cls.httpd = http.server.HTTPServer(
             cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
+            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
         )
         cls.port = cls.httpd.server_port
         cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
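
The upshot of the change: APIHandler no longer receives a preloaded (model, tokenizer) pair at construction time; it receives a provider object whose load(model) method is called with the model name taken from each incoming request. The test's DummyModelProvider returns the same preloaded model for every accepted name, which keeps the test fast while exercising the new code path.

A minimal sketch of the client side, assuming the server is running and exposes the OpenAI-style /v1/chat/completions route that mlx_lm.server mimics (the route and payload fields are assumptions about that API; only the model names come from the test above):

import requests

port = 8080  # hypothetical; the test binds port 0 and reads httpd.server_port

# The "model" field in the request body is what the server hands to its
# model provider's load() on demand; the DummyModelProvider above accepts
# "default_model" and "chat_model".
response = requests.post(
    f"http://localhost:{port}/v1/chat/completions",
    json={
        "model": "chat_model",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 10,
    },
)
print(response.json())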