# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer, _get_best_family
from pathlib import Path
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
def stopping_criteria(
tokens: List[int],
stop_id_sequences: List[List[int]],
eos_token_id: Union[int, None],
) -> StopCondition:
"""
Determines whether the token generation should stop based on predefined
conditions.
Args:
tokens (List[int]): The current sequence of generated tokens.
        stop_id_sequences (List[List[int]]): A list of integer lists, each
representing a sequence of token IDs. If the end of the `tokens`
list matches any of these sequences, the generation should stop.
eos_token_id (Union[int, None]): The token ID that represents the
end-of-sequence. If the last token in `tokens` matches this, the
generation should stop.
Returns:
StopCondition: A named tuple indicating whether the stop condition has
been met (`stop_met`) and how many tokens should be trimmed from the
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
if tokens[-len(stop_ids) :] == stop_ids:
return StopCondition(stop_met=True, trim_length=len(stop_ids))
return StopCondition(stop_met=False, trim_length=0)
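# Example (illustrative token ids):
#   stopping_criteria([5, 6, 7, 8], [[7, 8]], eos_token_id=2)
#   -> StopCondition(stop_met=True, trim_length=2)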
def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
"""
Checks if a suffix of s1 has overlap with a prefix of s2
Args:
s1 (Sequence): The first sequence
s2 (Sequence): The second sequence
Returns:
bool: If the two sequences have overlap
"""
max_overlap = min(len(s1), len(s2))
return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
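# Example: sequence_overlap([1, 2, 3], [3, 4]) -> True (the suffix [3] matches
# the prefix [3]), while sequence_overlap([1, 2], [3, 4]) -> False.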
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant follows the given rules no matter what."
),
"system": "ASSISTANT's RULE: ",
"user": "USER: ",
"assistant": "ASSISTANT: ",
"stop": "\n",
}
role_mapping = role_mapping if role_mapping is not None else default_role_mapping
prompt = ""
for line in messages:
role_prefix = role_mapping.get(line["role"], "")
stop = role_mapping.get("stop", "")
content = line.get("content", "")
prompt += f"{role_prefix}{content}{stop}"
prompt += role_mapping.get("assistant", "")
return prompt.rstrip()
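# Example with the default role mapping:
#   convert_chat([{"role": "user", "content": "Hi"}])
#   -> "USER: Hi\nASSISTANT:"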
def process_message_content(messages):
"""
Convert message content to a format suitable for `apply_chat_template`.
The function operates on messages in place. It converts the 'content' field
to a string instead of a list of text fragments.
Args:
        messages (list): A list of dictionaries, where each dictionary may
have a 'content' key containing a list of dictionaries with 'type' and
'text' keys.
Raises:
ValueError: If the 'content' type is not supported or if 'text' is missing.
"""
for message in messages:
content = message["content"]
if isinstance(content, list):
text_fragments = [
fragment["text"] for fragment in content if fragment["type"] == "text"
]
if len(text_fragments) != len(content):
raise ValueError("Only 'text' content type is supported.")
message["content"] = "".join(text_fragments)
@dataclass
class PromptCache:
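    """Prompt KV-cache state reused across requests: the cache itself, the
    (model, adapter) key it was built with, and the tokens it covers."""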
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
self.cli_args = cli_args
self.model_key = None
self.model = None
self.tokenizer = None
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
            raise RuntimeError(
                "Local models must be relative to the current working directory."
            )
    # Accept an adapter_path argument so adapters can be loaded dynamically
def load(self, model_path, adapter_path=None):
if self.model_key == (model_path, adapter_path):
return self.model, self.tokenizer
# Remove the old model if it exists.
self.model = None
self.tokenizer = None
self.model_key = None
# Building tokenizer_config
tokenizer_config = {
"trust_remote_code": True if self.cli_args.trust_remote_code else None
}
if self.cli_args.chat_template:
tokenizer_config["chat_template"] = self.cli_args.chat_template
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
adapter_path=(
adapter_path if adapter_path else self.cli_args.adapter_path
), # if the user doesn't change the model but adds an adapter path
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
model, tokenizer = load(
model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
)
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
self.model_key = (model_path, adapter_path)
self.model = model
self.tokenizer = tokenizer
return self.model, self.tokenizer
class APIHandler(BaseHTTPRequestHandler):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
        Create static request-specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "*")
self.send_header("Access-Control-Allow-Headers", "*")
def _set_completion_headers(self, status_code: int = 200):
self.send_response(status_code)
self.send_header("Content-type", "application/json")
self._set_cors_headers()
def _set_stream_headers(self, status_code: int = 200):
self.send_response(status_code)
self.send_header("Content-type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self._set_cors_headers()
def do_OPTIONS(self):
self._set_completion_headers(204)
self.end_headers()
def do_POST(self):
"""
Respond to a POST request from a client.
"""
endpoints = {
"/v1/completions": self.handle_text_completions,
"/v1/chat/completions": self.handle_chat_completions,
"/chat/completions": self.handle_chat_completions,
}
if self.path not in endpoints:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Fetch and parse request body
content_length = int(self.headers["Content-Length"])
raw_body = self.rfile.read(content_length)
self.body = json.loads(raw_body.decode())
        indent = "\t"  # Backslashes can't appear inside f-string expressions
logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}")
assert isinstance(
self.body, dict
), f"Request should be dict, but got {type(self.body)}"
# Extract request parameters from the body
self.stream = self.body.get("stream", False)
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
self.logit_bias = self.body.get("logit_bias", None)
self.logprobs = self.body.get("logprobs", -1)
self.validate_model_parameters()
# Load the model if needed
try:
self.model, self.tokenizer = self.model_provider.load(
self.requested_model, self.adapter
)
        except Exception:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Get stop id sequences, if provided
stop_words = self.body.get("stop")
stop_words = stop_words or []
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
self.tokenizer.encode(stop_word, add_special_tokens=False)
for stop_word in stop_words
]
        # Send the appropriate response headers
        if self.stream:
            self._set_stream_headers(200)
        else:
            self._set_completion_headers(200)
# Call endpoint specific method
prompt = endpoints[self.path]()
self.handle_completion(prompt, stop_id_sequences)
def validate_model_parameters(self):
"""
Validate the model parameters passed in the request for the correct types and values.
"""
if not isinstance(self.stream, bool):
raise ValueError("stream must be a boolean")
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
raise ValueError("max_tokens must be a non-negative integer")
if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
raise ValueError("temperature must be a non-negative float")
if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
raise ValueError("top_p must be a float between 0 and 1")
if (
not isinstance(self.repetition_penalty, (float, int))
or self.repetition_penalty < 0
):
raise ValueError("repetition_penalty must be a non-negative float")
if self.logprobs != -1 and not (0 < self.logprobs <= 10):
raise ValueError(
f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
)
if (
not isinstance(self.repetition_context_size, int)
or self.repetition_context_size < 0
):
raise ValueError("repetition_context_size must be a non-negative integer")
if self.logit_bias is not None:
if not isinstance(self.logit_bias, dict):
raise ValueError("logit_bias must be a dict of int to float")
try:
self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
except ValueError:
raise ValueError("logit_bias must be a dict of int to float")
if not isinstance(self.requested_model, str):
raise ValueError("model must be a string")
if self.adapter is not None and not isinstance(self.adapter, str):
raise ValueError("adapter must be a string")
def generate_response(
self,
text: str,
finish_reason: Union[Literal["length", "stop"], None],
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
token_logprobs: Optional[List[float]] = None,
top_tokens: Optional[List[Dict[int, float]]] = None,
tokens: Optional[List[int]] = None,
) -> dict:
"""
Generate a single response packet based on response type (stream or
not), completion type and parameters.
Args:
text (str): Text generated by model
finish_reason (Union[Literal["length", "stop"], None]): The reason the
response is being sent: "length", "stop" or `None`.
prompt_token_count (Optional[int]): The number of tokens in the prompt,
used to populate the "usage" field (not used when stream).
completion_token_count (Optional[int]): The number of tokens in the
response, used to populate the "usage" field (not used when stream).
token_logprobs (Optional[List[float]]): The log probabilities per token,
in token order.
top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
tokens to logprobs for the top N tokens at each token position.
tokens (Optional[List[int]]): List of tokens to return with logprobs structure
Returns:
dict: A dictionary containing the response, in the same format as
OpenAI's API.
"""
token_logprobs = token_logprobs if token_logprobs else []
top_logprobs = top_tokens if top_tokens else []
# Static response
response = {
"id": self.request_id,
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
"choices": [
{
"index": 0,
"logprobs": {
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"tokens": tokens,
},
"finish_reason": finish_reason,
}
],
}
if not self.stream:
if not (
isinstance(prompt_token_count, int)
and isinstance(completion_token_count, int)
):
raise ValueError(
"Response type is complete, but token counts not provided"
)
response["usage"] = {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count,
}
choice = response["choices"][0]
# Add dynamic response
if self.object_type.startswith("chat.completion"):
key_name = "delta" if self.stream else "message"
choice[key_name] = {"role": "assistant", "content": text}
elif self.object_type == "text_completion":
choice.update(text=text)
else:
            raise ValueError(f"Unsupported response type: {self.object_type}")
return response
def get_prompt_cache(self, prompt):
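        """Reuse the cached KV state when the new prompt extends the cached
        tokens; otherwise rebuild the cache. Returns the suffix of `prompt`
        that still needs to be processed."""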
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
            self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
            # Reset the cached tokens as well so the stale prefix is not reused.
            self.prompt_cache.tokens = []
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion(
self,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
        Generate a response to a prompt and send it to the client, streaming
        it if requested or in a single batch otherwise.
Args:
prompt (List[int]): The tokenized prompt.
            stop_id_sequences (List[List[int]]): A list of tokenized stop
                sequences passed to the stopping_criteria function
"""
tokens = []
finish_reason = "length"
stop_sequence_suffix = None
if self.stream:
self.end_headers()
logging.debug(f"Starting stream:")
else:
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
prompt = self.get_prompt_cache(prompt)
text = ""
tic = time.perf_counter()
sampler = make_sampler(self.temperature, top_p=self.top_p)
logits_processors = make_logits_processors(
self.logit_bias, self.repetition_penalty, self.repetition_context_size
)
for gen_response in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=self.prompt_cache.cache,
):
segment = gen_response.text
text += segment
logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs
tokens.append(token)
if self.logprobs > 0:
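                # argpartition gives the indices of the N largest logprobs
                # without fully sorting the vocabulary.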
sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
stop_condition = stopping_criteria(
tokens, stop_id_sequences, self.tokenizer.eos_token_id
)
if stop_condition.stop_met:
finish_reason = "stop"
if stop_condition.trim_length:
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
)
text = text[: -len(stop_sequence_suffix)]
break
            if self.stream:
                # If the end of tokens overlaps with a stop sequence, generate new
                # tokens until we know if the stop sequence is hit or not
                if any(
                    sequence_overlap(tokens, sequence)
                    for sequence in stop_id_sequences
                ):
                    continue
                # Send everything not yet streamed; this also flushes segments
                # that were previously withheld by the overlap check.
                unsent = text[n_sent:]
                if unsent:
                    response = self.generate_response(unsent, None)
                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
                    self.wfile.flush()
                    n_sent = len(text)
self.prompt_cache.tokens.extend(tokens)
logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
        if self.stream:
            # Send any remaining withheld text (already trimmed of the stop
            # sequence) together with the finish reason.
            response = self.generate_response(text[n_sent:], finish_reason)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
            if self.stream_options is not None and self.stream_options.get(
                "include_usage"
            ):
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
else:
response = self.generate_response(
text,
finish_reason,
len(prompt),
len(tokens),
token_logprobs=token_logprobs,
top_tokens=top_tokens,
tokens=tokens,
)
response_json = json.dumps(response).encode()
            indent = "\t"  # Backslashes can't appear inside f-string expressions
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
def completion_usage_response(
self,
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
):
response = {
"id": self.request_id,
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
"choices": [],
"usage": {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count,
},
}
return response
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
Returns:
            List[int]: The tokenized prompt from the request body
"""
body = self.body
assert "messages" in body, "Request did not contain messages"
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if self.tokenizer.chat_template:
messages = body["messages"]
process_message_content(messages)
prompt = self.tokenizer.apply_chat_template(
messages,
body.get("tools", None),
add_generation_prompt=True,
)
else:
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return prompt
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
Returns:
            List[int]: The tokenized prompt from the request body
"""
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
return self.tokenizer.encode(self.body["prompt"])
def do_GET(self):
"""
Respond to a GET request from a client.
"""
if self.path == "/v1/models":
self.handle_models_request()
else:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
def handle_models_request(self):
"""
Handle a GET request for the /v1/models endpoint.
"""
self._set_completion_headers(200)
self.end_headers()
# Scan the cache directory for downloaded mlx models
hf_cache_info = scan_cache_dir()
downloaded_models = [
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
]
# Create a list of available models
models = [
{
"id": repo.repo_id,
"object": "model",
"created": self.created,
}
for repo in downloaded_models
]
response = {"object": "list", "data": models}
response_json = json.dumps(response).encode()
self.wfile.write(response_json)
self.wfile.flush()
def run(
host: str,
port: int,
model_provider: ModelProvider,
server_class=HTTPServer,
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
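    # http.server's _get_best_family resolves the host and selects AF_INET6
    # when it maps to an IPv6 address (e.g. "::1" or "::"), enabling IPv6.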
server_class.address_family, server_address = _get_best_family(*server_address)
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "
"it only implements basic security checks."
)
logging.info(f"Starting httpd at {host} on port {port}...")
httpd.serve_forever()
def main():
    parser = argparse.ArgumentParser(description="MLX HTTP Server.")
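    # Example launch (model name is illustrative):
    #   python -m mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit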
parser.add_argument(
"--model",
type=str,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host for the HTTP server (default: 127.0.0.1)",
)
parser.add_argument(
"--port",
type=int,
default=8080,
help="Port for the HTTP server (default: 8080)",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level (default: INFO)",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
required=False,
)
parser.add_argument(
"--chat-template",
type=str,
default="",
help="Specify a chat template for the tokenizer",
required=False,
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
args = parser.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format="%(asctime)s - %(levelname)s - %(message)s",
)
if args.cache_limit_gb is not None:
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
run(args.host, args.port, ModelProvider(args))
if __name__ == "__main__":
main()