Server loads the model on demand from the request (#851)

Angelos Katharopoulos 2024-06-27 11:37:57 -07:00 committed by GitHub
parent 538339b599
commit f212b770d8
2 changed files with 82 additions and 28 deletions
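
In practice, the `model` field of the request body now selects which weights the server loads: the name "default_model" resolves to whatever was passed via `--model` at startup, and any other value (a Hugging Face repo or a local path under the working directory) is loaded on demand and kept for subsequent requests that name the same model. A minimal client sketch, assuming the server is listening on localhost:8080 and using the OpenAI-style /v1/chat/completions route; "./my-local-model" is a hypothetical path:

import json
from urllib.request import Request, urlopen

def chat(model: str, content: str) -> dict:
    # The body mirrors the OpenAI chat-completions schema the server accepts.
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 64,
    }
    req = Request(
        "http://localhost:8080/v1/chat/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urlopen(req) as resp:
        return json.load(resp)

print(chat("default_model", "Hello!"))     # uses the model given via --model
print(chat("./my-local-model", "Hello!"))  # hypothetical local path, loaded on demand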


@@ -8,6 +8,7 @@ import uuid
import warnings
from functools import lru_cache
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
self.cli_args = cli_args
self.model_key = None
self.model = None
self.tokenizer = None
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
raise RuntimeError(
"Local models must be relative to the current working dir."
)
def load(self, model_path):
if self.model_key == model_path:
return self.model, self.tokenizer
# Remove the old model if it exists.
self.model = None
self.tokenizer = None
# Building tokenizer_config
tokenizer_config = {
"trust_remote_code": True if self.cli_args.trust_remote_code else None
}
if self.cli_args.chat_template:
tokenizer_config["chat_template"] = self.cli_args.chat_template
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
adapter_path=self.cli_args.adapter_path,
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
self.model_key = model_path
self.model = model
self.tokenizer = tokenizer
return self.model, self.tokenizer
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
"""
Create static request specific metadata
"""
self.model = model
self.tokenizer = tokenizer
self.created = int(time.time())
self.model_provider = model_provider
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
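
The hunk above makes `ModelProvider` the single owner of the loaded weights: calling `load` with the key already stored in `model_key` returns the cached model/tokenizer pair, while a new key first drops the old pair and then loads the requested one. A rough sketch of that contract, assuming the post-commit `mlx_lm.server` module; the Namespace fields mirror the CLI flags referenced in `load`, and the local path is a placeholder:

import argparse
from mlx_lm.server import ModelProvider

args = argparse.Namespace(
    model="mlx-community/Qwen1.5-0.5B-Chat-4bit",  # preloaded as "default_model"
    adapter_path=None,
    trust_remote_code=False,
    chat_template=None,
    use_default_chat_template=False,
)

provider = ModelProvider(args)           # __init__ eagerly loads the default model
m1, t1 = provider.load("default_model")  # same key: the cached pair is returned
m2, t2 = provider.load("./other-model")  # new key: the default is released first,
                                         # then the placeholder path is loaded
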
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
self.logprobs = self.body.get("logprobs", -1)
self.validate_model_parameters()
# Load the model if needed
try:
self.model, self.tokenizer = self.model_provider.load(self.requested_model)
except:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Get stop id sequences, if provided
stop_words = self.body.get("stop")
stop_words = stop_words or []
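
The try/except above turns any failure to load the requested model (an unknown repo, a bad local path, or a local path outside the working directory) into a plain 404 response instead of an unhandled exception in the handler. A client-side sketch of that behaviour, with the URL and model name as placeholders:

import json
from urllib.error import HTTPError
from urllib.request import Request, urlopen

req = Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps({
        "model": "./no-such-model",  # hypothetical path that fails to load
        "messages": [{"role": "user", "content": "hi"}],
    }).encode(),
    headers={"Content-Type": "application/json"},
)
try:
    urlopen(req)
except HTTPError as err:
    print(err.code, err.read().decode())  # expected: 404 Not Found
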
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
def run(
host: str,
port: int,
model: nn.Module,
tokenizer: TokenizerWrapper,
model_provider: ModelProvider,
server_class=HTTPServer,
handler_class=APIHandler,
):
server_address = (host, port)
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
parser.add_argument(
"--model",
type=str,
required=True,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
@@ -598,20 +660,7 @@
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.chat_template:
tokenizer_config["chat_template"] = args.chat_template
model, tokenizer = load(
args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
run(args.host, args.port, model, tokenizer)
run(args.host, args.port, ModelProvider(args))
if __name__ == "__main__":
    main()

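With `required=True` removed, the server can now start without a `--model`; in that case nothing is preloaded and the first request that names a model triggers `ModelProvider.load`. A hedged launch sketch; the flags shown are the standard host/port options and the values are arbitrary:

import subprocess

# Start the server with no --model; nothing is loaded until a request arrives.
server = subprocess.Popen(
    ["python", "-m", "mlx_lm.server", "--host", "127.0.0.1", "--port", "8080"]
)
# ... issue requests as in the client sketches above, then shut down:
server.terminate()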

@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
from mlx_lm.utils import load
class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
def load(self, model):
assert model in ["default_model", "chat_model"]
return self.model, self.tokenizer
class TestServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
cls.model_provider = DummyModelProvider()
cls.server_address = ("localhost", 0)
cls.httpd = http.server.HTTPServer(
cls.server_address,
lambda *args, **kwargs: APIHandler(
cls.model, cls.tokenizer, *args, **kwargs
),
lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
)
cls.port = cls.httpd.server_port
cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
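
Since `DummyModelProvider.load` accepts both "default_model" and "chat_model", the updated tests can exercise per-request model selection without loading a second real model. A hedged sketch of such a request; the port is hard-coded here for illustration, whereas the real tests read it from `cls.port`:

import http.client
import json

conn = http.client.HTTPConnection("localhost", 8080)
conn.request(
    "POST",
    "/v1/chat/completions",
    body=json.dumps({
        "model": "chat_model",  # resolved by DummyModelProvider, not ModelProvider
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 8,
    }),
    headers={"Content-Type": "application/json"},
)
print(conn.getresponse().status)  # expected: 200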