diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 19f3f46a..b53971a3 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -8,6 +8,7 @@ import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
+from pathlib import Path
 from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+class ModelProvider:
+    def __init__(self, cli_args: argparse.Namespace):
+        """Load models on demand and persist them across the whole process."""
+        self.cli_args = cli_args
+        self.model_key = None
+        self.model = None
+        self.tokenizer = None
+
+        # Preload the default model if it is provided
+        if self.cli_args.model is not None:
+            self.load("default_model")
+
+    def _validate_model_path(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
+            raise RuntimeError(
+                "Local models must be relative to the current working dir."
+            )
+
+    def load(self, model_path):
+        if self.model_key == model_path:
+            return self.model, self.tokenizer
+
+        # Remove the old model if it exists.
+        self.model = None
+        self.tokenizer = None
+
+        # Building tokenizer_config
+        tokenizer_config = {
+            "trust_remote_code": True if self.cli_args.trust_remote_code else None
+        }
+        if self.cli_args.chat_template:
+            tokenizer_config["chat_template"] = self.cli_args.chat_template
+
+        if model_path == "default_model" and self.cli_args.model is not None:
+            model, tokenizer = load(
+                self.cli_args.model,
+                adapter_path=self.cli_args.adapter_path,
+                tokenizer_config=tokenizer_config,
+            )
+        else:
+            self._validate_model_path(model_path)
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+
+        if self.cli_args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+
+        self.model_key = model_path
+        self.model = model
+        self.tokenizer = tokenizer
+
+        return self.model, self.tokenizer
+
+
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
+    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
         """
         Create static request specific metadata
         """
-        self.model = model
-        self.tokenizer = tokenizer
         self.created = int(time.time())
+        self.model_provider = model_provider
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
         self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()
 
+        # Load the model if needed
+        try:
+            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+        except:
+            self._set_completion_headers(404)
+            self.end_headers()
+            self.wfile.write(b"Not Found")
+            return
+
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop")
         stop_words = stop_words or []
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
 def run(
     host: str,
     port: int,
-    model: nn.Module,
-    tokenizer: TokenizerWrapper,
+    model_provider: ModelProvider,
     server_class=HTTPServer,
     handler_class=APIHandler,
 ):
     server_address = (host, port)
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
@@ -598,20 +660,7 @@ def main():
         logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    if args.chat_template:
-        tokenizer_config["chat_template"] = args.chat_template
-
-    model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-
-    run(args.host, args.port, model, tokenizer)
+    run(args.host, args.port, ModelProvider(args))
 
 
 if __name__ == "__main__":
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index 998ad1c7..4d71a5a3 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
 from mlx_lm.utils import load
 
 
+class DummyModelProvider:
+    def __init__(self):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        self.model, self.tokenizer = load(HF_MODEL_PATH)
+
+    def load(self, model):
+        assert model in ["default_model", "chat_model"]
+        return self.model, self.tokenizer
+
+
 class TestServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
-
+        cls.model_provider = DummyModelProvider()
         cls.server_address = ("localhost", 0)
         cls.httpd = http.server.HTTPServer(
             cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
+            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
        )
         cls.port = cls.httpd.server_port
         cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
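
A quick way to exercise the new per-request loading is to start the server (with or without --model) and name a model in the request body. The sketch below is illustrative only and not part of the patch: it assumes the server is listening on localhost:8080, that the handler fills self.requested_model from the request body's "model" field (with "default_model" as the fallback, matching the sentinel used by ModelProvider), and that the model path shown is resolvable locally or from the Hugging Face Hub.

# Hypothetical client-side check of the dynamic model loading; adjust the host,
# port, and model path to your setup.
import json
from urllib import request

def chat(model: str, content: str, host: str = "localhost", port: int = 8080) -> dict:
    """POST a chat completion request that names the model to load."""
    body = json.dumps(
        {
            "model": model,  # "default_model" falls back to the --model CLI argument
            "messages": [{"role": "user", "content": content}],
            "max_tokens": 32,
        }
    ).encode()
    req = request.Request(
        f"http://{host}:{port}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as response:
        return json.loads(response.read())

# The first call triggers ModelProvider.load; repeating it reuses the cached model.
print(chat("mlx-community/Qwen1.5-0.5B-Chat-4bit", "Hello!"))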