Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Server loads the model on demand from the request (#851)
This commit is contained in:
parent 538339b599
commit f212b770d8
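The change below replaces the single model/tokenizer pair that used to be loaded at server start-up with a ModelProvider that loads, caches, and swaps models based on what each request asks for. As a minimal sketch of the effect (not part of the commit), a client can name any loadable model in the request body and the server loads it on first use; this assumes the default localhost:8080 address, the OpenAI-style chat completions endpoint, and a placeholder model name:

import json
import urllib.request

# Placeholder model: any Hugging Face repo, or a local path relative to the
# server's working directory, that the server is allowed to load.
payload = {
    "model": "mlx-community/Qwen1.5-0.5B-Chat-4bit",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 32,
}
request = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    print(json.loads(response.read()))

Judging by the load("default_model") call in the new class, a request that omits "model" appears to fall back to the "default_model" key, i.e. whatever was passed via --model, if anything.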
@@ -8,6 +8,7 @@ import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
+from pathlib import Path
 from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union

 import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()


+class ModelProvider:
+    def __init__(self, cli_args: argparse.Namespace):
+        """Load models on demand and persist them across the whole process."""
+        self.cli_args = cli_args
+        self.model_key = None
+        self.model = None
+        self.tokenizer = None
+
+        # Preload the default model if it is provided
+        if self.cli_args.model is not None:
+            self.load("default_model")
+
+    def _validate_model_path(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
+            raise RuntimeError(
+                "Local models must be relative to the current working dir."
+            )
+
+    def load(self, model_path):
+        if self.model_key == model_path:
+            return self.model, self.tokenizer
+
+        # Remove the old model if it exists.
+        self.model = None
+        self.tokenizer = None
+
+        # Building tokenizer_config
+        tokenizer_config = {
+            "trust_remote_code": True if self.cli_args.trust_remote_code else None
+        }
+        if self.cli_args.chat_template:
+            tokenizer_config["chat_template"] = self.cli_args.chat_template
+
+        if model_path == "default_model" and self.cli_args.model is not None:
+            model, tokenizer = load(
+                self.cli_args.model,
+                adapter_path=self.cli_args.adapter_path,
+                tokenizer_config=tokenizer_config,
+            )
+        else:
+            self._validate_model_path(model_path)
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+
+        if self.cli_args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+
+        self.model_key = model_path
+        self.model = model
+        self.tokenizer = tokenizer
+
+        return self.model, self.tokenizer
+
+
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
+    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
         """
         Create static request specific metadata
         """
-        self.model = model
-        self.tokenizer = tokenizer
         self.created = int(time.time())
+        self.model_provider = model_provider
         super().__init__(*args, **kwargs)

     def _set_cors_headers(self):
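Between the two hunks, a short sketch of the caching behaviour the new class implements. The namespace fields mirror the CLI flags ModelProvider reads above, and the model names are placeholders; treat this as an illustration rather than code from the commit:

import argparse
from mlx_lm.server import ModelProvider

# Hypothetical stand-in for the parsed CLI arguments.
args = argparse.Namespace(
    model="mlx-community/Qwen1.5-0.5B-Chat-4bit",
    adapter_path=None,
    trust_remote_code=False,
    chat_template=None,
    use_default_chat_template=False,
)

provider = ModelProvider(args)                      # preloads args.model under the key "default_model"
model, tokenizer = provider.load("default_model")   # returns the preloaded pair
model, tokenizer = provider.load("default_model")   # same key: cache hit, nothing is reloaded
model, tokenizer = provider.load("mlx-community/Another-Model-4bit")  # new key: old model dropped, new one loaded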
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
         self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()

+        # Load the model if needed
+        try:
+            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+        except:
+            self._set_completion_headers(404)
+            self.end_headers()
+            self.wfile.write(b"Not Found")
+            return
+
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop")
         stop_words = stop_words or []
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
 def run(
     host: str,
     port: int,
-    model: nn.Module,
-    tokenizer: TokenizerWrapper,
+    model_provider: ModelProvider,
     server_class=HTTPServer,
     handler_class=APIHandler,
 ):
     server_address = (host, port)
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
@@ -598,20 +660,7 @@ def main():
         logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

-    # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    if args.chat_template:
-        tokenizer_config["chat_template"] = args.chat_template
-
-    model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-
-    run(args.host, args.port, model, tokenizer)
+    run(args.host, args.port, ModelProvider(args))


 if __name__ == "__main__":
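The test-file hunk follows. Seen from the entry point, the change boils down to main() handing run() a ModelProvider instead of a preloaded model and tokenizer. A sketch of the equivalent programmatic start-up, again with placeholder values (the mlx_lm.server CLI remains the normal way to launch it):

import argparse
from mlx_lm.server import ModelProvider, run

# model=None is allowed now that --model is no longer required; every model
# is then loaded on demand from the incoming requests.
args = argparse.Namespace(
    model=None,
    adapter_path=None,
    trust_remote_code=False,
    chat_template=None,
    use_default_chat_template=False,
)
run("localhost", 8080, ModelProvider(args))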
@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
 from mlx_lm.utils import load


+class DummyModelProvider:
+    def __init__(self):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        self.model, self.tokenizer = load(HF_MODEL_PATH)
+
+    def load(self, model):
+        assert model in ["default_model", "chat_model"]
+        return self.model, self.tokenizer
+
+
 class TestServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
-
+        cls.model_provider = DummyModelProvider()
         cls.server_address = ("localhost", 0)
         cls.httpd = http.server.HTTPServer(
             cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
+            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
         )
         cls.port = cls.httpd.server_port
         cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)