mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Server loads the model on demand from the request (#851)
This commit is contained in:

committed by
GitHub

parent
538339b599
commit
f212b770d8
@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
|
||||
from mlx_lm.utils import load
|
||||
|
||||
|
||||
class DummyModelProvider:
|
||||
def __init__(self):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||
|
||||
def load(self, model):
|
||||
assert model in ["default_model", "chat_model"]
|
||||
return self.model, self.tokenizer
|
||||
|
||||
|
||||
class TestServer(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
|
||||
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||
|
||||
cls.model_provider = DummyModelProvider()
|
||||
cls.server_address = ("localhost", 0)
|
||||
cls.httpd = http.server.HTTPServer(
|
||||
cls.server_address,
|
||||
lambda *args, **kwargs: APIHandler(
|
||||
cls.model, cls.tokenizer, *args, **kwargs
|
||||
),
|
||||
lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
|
||||
)
|
||||
cls.port = cls.httpd.server_port
|
||||
cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
|
||||
|
Reference in New Issue
Block a user