Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-25 01:41:19 +08:00
Server loads the model on demand from the request (#851)
This commit is contained in:
parent 538339b599
commit f212b770d8
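The change below makes the server resolve the model from each request's "model" field instead of loading a single model at startup. A minimal client-side sketch of the new behavior, assuming the server listens on localhost:8080 and exposes the OpenAI-style /v1/chat/completions route (host, port, and payload values here are illustrative, not taken from the diff):

import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        # The handler forwards this value to ModelProvider.load(), so each
        # request can name the model it wants; repeats reuse the cached one.
        "model": "mlx-community/Qwen1.5-0.5B-Chat-4bit",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 32,
    },
)
print(response.status_code, response.json())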
@@ -8,6 +8,7 @@ import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
+from pathlib import Path
 from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -81,14 +82,68 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+class ModelProvider:
+    def __init__(self, cli_args: argparse.Namespace):
+        """Load models on demand and persist them across the whole process."""
+        self.cli_args = cli_args
+        self.model_key = None
+        self.model = None
+        self.tokenizer = None
+
+        # Preload the default model if it is provided
+        if self.cli_args.model is not None:
+            self.load("default_model")
+
+    def _validate_model_path(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
+            raise RuntimeError(
+                "Local models must be relative to the current working dir."
+            )
+
+    def load(self, model_path):
+        if self.model_key == model_path:
+            return self.model, self.tokenizer
+
+        # Remove the old model if it exists.
+        self.model = None
+        self.tokenizer = None
+
+        # Building tokenizer_config
+        tokenizer_config = {
+            "trust_remote_code": True if self.cli_args.trust_remote_code else None
+        }
+        if self.cli_args.chat_template:
+            tokenizer_config["chat_template"] = self.cli_args.chat_template
+
+        if model_path == "default_model" and self.cli_args.model is not None:
+            model, tokenizer = load(
+                self.cli_args.model,
+                adapter_path=self.cli_args.adapter_path,
+                tokenizer_config=tokenizer_config,
+            )
+        else:
+            self._validate_model_path(model_path)
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+
+        if self.cli_args.use_default_chat_template:
+            if tokenizer.chat_template is None:
+                tokenizer.chat_template = tokenizer.default_chat_template
+
+        self.model_key = model_path
+        self.model = model
+        self.tokenizer = tokenizer
+
+        return self.model, self.tokenizer
+
+
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
+    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
         """
         Create static request specific metadata
         """
-        self.model = model
-        self.tokenizer = tokenizer
         self.created = int(time.time())
+        self.model_provider = model_provider
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
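The new ModelProvider caches one model/tokenizer pair keyed by the requested path: repeating the same key returns the cached pair, while a new key drops the old model and loads the requested one. A rough standalone sketch of that behavior, with the argparse.Namespace fields filled in by hand (the field values are assumptions for illustration, not part of the diff):

import argparse

from mlx_lm.server import ModelProvider

# Illustrative CLI arguments; in the real server these come from main()'s parser.
args = argparse.Namespace(
    model=None,  # no default model is preloaded
    adapter_path=None,
    trust_remote_code=False,
    chat_template=None,
    use_default_chat_template=False,
)

provider = ModelProvider(args)

# The first request for a key triggers a load from the Hub or a local path...
model, tokenizer = provider.load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
# ...and repeating the same key returns the cached pair without reloading.
model_again, _ = provider.load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
assert model is model_again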
@@ -148,6 +203,15 @@ class APIHandler(BaseHTTPRequestHandler):
         self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()
 
+        # Load the model if needed
+        try:
+            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+        except:
+            self._set_completion_headers(404)
+            self.end_headers()
+            self.wfile.write(b"Not Found")
+            return
+
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop")
         stop_words = stop_words or []
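If ModelProvider.load() raises, for example for a local path outside the working directory or a repo that cannot be fetched, the handler now answers the request with 404 instead of failing mid-response. A client-side illustration under the same assumed host and port as above (the bogus model name is made up):

import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={"model": "not/a-real-model", "messages": [{"role": "user", "content": "hi"}]},
)
# Per the handler change above, a failed load yields status 404 with body b"Not Found".
print(resp.status_code)  # expected: 404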
@@ -513,15 +577,14 @@ class APIHandler(BaseHTTPRequestHandler):
 def run(
     host: str,
     port: int,
-    model: nn.Module,
-    tokenizer: TokenizerWrapper,
+    model_provider: ModelProvider,
     server_class=HTTPServer,
     handler_class=APIHandler,
 ):
     server_address = (host, port)
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
@@ -536,7 +599,6 @@ def main():
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
         help="The path to the MLX model weights, tokenizer, and config",
     )
     parser.add_argument(
@@ -598,20 +660,7 @@ def main():
     logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
     mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    if args.chat_template:
-        tokenizer_config["chat_template"] = args.chat_template
-
-    model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
-    )
-
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-
-    run(args.host, args.port, model, tokenizer)
+    run(args.host, args.port, ModelProvider(args))
 
 
 if __name__ == "__main__":
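With main() reduced to handing a ModelProvider to run(), and --model no longer required, the server can also be started programmatically with the same wiring. A sketch assuming the namespace carries the fields ModelProvider reads (the values below are illustrative):

import argparse

from mlx_lm.server import ModelProvider, run

args = argparse.Namespace(
    model="mlx-community/Qwen1.5-0.5B-Chat-4bit",  # preloaded as "default_model"
    adapter_path=None,
    trust_remote_code=False,
    chat_template=None,
    use_default_chat_template=False,
)

# Equivalent to what main() now does after parsing arguments.
run("localhost", 8080, ModelProvider(args))

The test hunk below swaps in a DummyModelProvider that exposes the same load() interface, so the server tests exercise the new handler wiring without depending on on-demand loading.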
@@ -7,19 +7,24 @@ from mlx_lm.server import APIHandler
 from mlx_lm.utils import load
 
 
+class DummyModelProvider:
+    def __init__(self):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        self.model, self.tokenizer = load(HF_MODEL_PATH)
+
+    def load(self, model):
+        assert model in ["default_model", "chat_model"]
+        return self.model, self.tokenizer
+
+
 class TestServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
-
+        cls.model_provider = DummyModelProvider()
         cls.server_address = ("localhost", 0)
         cls.httpd = http.server.HTTPServer(
             cls.server_address,
-            lambda *args, **kwargs: APIHandler(
-                cls.model, cls.tokenizer, *args, **kwargs
-            ),
+            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
         )
         cls.port = cls.httpd.server_port
         cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)