Server loads the model on demand from the request (#851)

Angelos Katharopoulos 2024-06-27 11:37:57 -07:00 committed by GitHub
parent 538339b599
commit f212b770d8
2 changed files with 82 additions and 28 deletions
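The change makes the served model a per-request choice: the handler asks a ModelProvider for whatever name arrives in the request body's "model" field and loads it on demand. As a rough illustration, a client request might look like the sketch below; it assumes a locally running server on port 8080 exposing the OpenAI-style /v1/chat/completions route, and the model name is only a placeholder for a repo or local path the server is allowed to load.

import json
from urllib import request

# Name a model in the request body; the server loads it on demand if it
# is not the one currently held in memory (placeholder model name).
body = {
    "model": "mlx-community/Qwen1.5-0.5B-Chat-4bit",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 32,
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(json.loads(resp.read()))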

llms/mlx_lm/server.py

@@ -8,6 +8,7 @@ import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
+from pathlib import Path
 from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+class ModelProvider:
+    def __init__(self, cli_args: argparse.Namespace):
+        """Load models on demand and persist them across the whole process."""
+        self.cli_args = cli_args
+        self.model_key = None
+        self.model = None
+        self.tokenizer = None
+
+        # Preload the default model if it is provided
+        if self.cli_args.model is not None:
+            self.load("default_model")
+
+    def _validate_model_path(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
+            raise RuntimeError(
+                "Local models must be relative to the current working dir."
+            )
+
+    def load(self, model_path):
+        if self.model_key == model_path:
+            return self.model, self.tokenizer
+
+        # Remove the old model if it exists.
+        self.model = None
+        self.tokenizer = None
+
+        # Building tokenizer_config
+        tokenizer_config = {
+            "trust_remote_code": True if self.cli_args.trust_remote_code else None
+        }
+        if self.cli_args.chat_template:
+            tokenizer_config["chat_template"] = self.cli_args.chat_template
+
+        if model_path == "default_model" and self.cli_args.model is not None:
+            model, tokenizer = load(
+                self.cli_args.model,
+                adapter_path=self.cli_args.adapter_path,
+                tokenizer_config=tokenizer_config,
+            )
+        else:
+            self._validate_model_path(model_path)
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+
+        if self.cli_args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+
+        self.model_key = model_path
+        self.model = model
+        self.tokenizer = tokenizer
+
+        return self.model, self.tokenizer
+
+
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
+    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
         """
         Create static request specific metadata
         """
-        self.model = model
-        self.tokenizer = tokenizer
         self.created = int(time.time())
+        self.model_provider = model_provider
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
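For clarity, ModelProvider can also be exercised outside the HTTP handler. The sketch below is a minimal, hypothetical usage example: it assumes this version of mlx_lm.server is importable and builds an argparse.Namespace with the same attribute names the provider reads from the CLI flags; the model name is a placeholder.

import argparse

from mlx_lm.server import ModelProvider

# Mirror the server's CLI namespace; --model is optional now, so no
# default model is preloaded here.
args = argparse.Namespace(
    model=None,
    adapter_path=None,
    trust_remote_code=False,
    chat_template="",
    use_default_chat_template=False,
)

provider = ModelProvider(args)
# The first call loads the requested model; repeated calls with the same
# key return the cached (model, tokenizer) pair without reloading.
model, tokenizer = provider.load("mlx-community/Qwen1.5-0.5B-Chat-4bit")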
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
         self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()
 
+        # Load the model if needed
+        try:
+            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+        except:
+            self._set_completion_headers(404)
+            self.end_headers()
+            self.wfile.write(b"Not Found")
+            return
+
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop")
         stop_words = stop_words or []
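The bare except above turns any load failure, including the working-directory check in _validate_model_path, into a 404 for the client. A hedged client-side sketch of that failure path, again assuming a local server on port 8080 and a placeholder model name that cannot be loaded:

import json
from urllib import error, request

body = {
    "model": "some/model-that-cannot-be-loaded",
    "messages": [{"role": "user", "content": "Hi"}],
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
try:
    request.urlopen(req)
except error.HTTPError as e:
    print(e.code)  # expected: 404 with body b"Not Found"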
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
 def run(
     host: str,
     port: int,
-    model: nn.Module,
-    tokenizer: TokenizerWrapper,
+    model_provider: ModelProvider,
     server_class=HTTPServer,
     handler_class=APIHandler,
 ):
     server_address = (host, port)
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
@@ -598,20 +660,7 @@ def main():
     logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
     mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    if args.chat_template:
-        tokenizer_config["chat_template"] = args.chat_template
-
-    model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-
-    run(args.host, args.port, model, tokenizer)
+    run(args.host, args.port, ModelProvider(args))
 
 
 if __name__ == "__main__":

llms/tests/test_server.py

@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
 from mlx_lm.utils import load
 
 
+class DummyModelProvider:
+    def __init__(self):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        self.model, self.tokenizer = load(HF_MODEL_PATH)
+
+    def load(self, model):
+        assert model in ["default_model", "chat_model"]
+        return self.model, self.tokenizer
+
+
 class TestServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.model_provider = DummyModelProvider()
 
         cls.server_address = ("localhost", 0)
         cls.httpd = http.server.HTTPServer(
             cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
+            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
         )
 
         cls.port = cls.httpd.server_port
         cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
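Since DummyModelProvider accepts both "default_model" and "chat_model", a test can exercise a non-default model name without loading anything extra. A hypothetical addition in the style of this test module, assuming it lives in the same file so TestServer and the server it starts in setUpClass are in scope:

import http.client
import json


class TestRequestedModel(TestServer):
    def test_chat_model_request(self):
        # Name the non-default key; the dummy provider serves the same
        # stubbed model for it, so the request should succeed.
        conn = http.client.HTTPConnection("localhost", self.port)
        body = {
            "model": "chat_model",
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 8,
        }
        conn.request(
            "POST",
            "/v1/chat/completions",
            body=json.dumps(body),
            headers={"Content-Type": "application/json"},
        )
        response = conn.getresponse()
        self.assertEqual(response.status, 200)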