Prompt caching in mlx_lm.server (#1026)

* caching in server

* nits

* fix tests

* don't throw if no metal

* comments
Author: Awni Hannun
Date: 2024-10-14 10:57:22 -07:00
Committed by: GitHub
Parent: 8dca1a2f60
Commit: 605c4854f1
4 changed files with 151 additions and 32 deletions
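
In practice the change means that consecutive requests sharing a prompt prefix (for example, a growing chat history) no longer re-process the shared tokens on the server. A hedged client-side sketch follows; the launch command, port, model name, and response fields are assumptions based on the OpenAI-style API the server mimics, not part of this commit.

# Assumed launch command:  python -m mlx_lm.server --model <hf-repo-or-path> --port 8080
import json
from urllib.request import Request, urlopen

URL = "http://localhost:8080/v1/chat/completions"  # assumed default endpoint

def chat(messages):
    body = json.dumps({"messages": messages, "max_tokens": 64}).encode()
    req = Request(URL, data=body, headers={"Content-Type": "application/json"})
    with urlopen(req) as resp:
        return json.loads(resp.read())

history = [{"role": "user", "content": "Summarize prompt caching in one line."}]
first = chat(history)
history.append(first["choices"][0]["message"])          # keep the shared prefix intact
history.append({"role": "user", "content": "Now expand that into a paragraph."})
second = chat(history)  # the shared prefix is served from the server-side prompt cache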


@@ -3,19 +3,38 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
@@ -156,12 +182,21 @@ class ModelProvider:
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
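
http.server constructs a new APIHandler for every request, so anything that must outlive a single request, namely the prompt cache and the fingerprint, has to be created elsewhere and injected here, which is what run() does at the bottom of this diff. A standalone sketch of that pattern (names and port are illustrative, not part of the commit):

from http.server import BaseHTTPRequestHandler, HTTPServer

class CountingHandler(BaseHTTPRequestHandler):
    def __init__(self, shared_state, *args, **kwargs):
        self.shared_state = shared_state      # same dict across all requests
        super().__init__(*args, **kwargs)     # handles the request

    def do_GET(self):
        self.shared_state["hits"] += 1
        self.send_response(200)
        self.end_headers()
        self.wfile.write(str(self.shared_state["hits"]).encode())

state = {"hits": 0}
server = HTTPServer(("localhost", 8081), lambda *a, **kw: CountingHandler(state, *a, **kw))
# server.serve_forever()  # each GET would return an ever-growing hit count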
def _set_cors_headers(self):
@@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
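
The token limit now prefers OpenAI's newer max_completion_tokens field and only falls back to max_tokens, with the fallback default raised from 100 to 512. A small illustration of the precedence (values are arbitrary):

# Illustrative request body only; values are arbitrary.
body = {
    "messages": [{"role": "user", "content": "hi"}],
    "max_completion_tokens": 256,  # used when present
    "max_tokens": 64,              # consulted only if max_completion_tokens is absent
}
# With neither field in the request, the server now caps generation at 512 tokens.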
@@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
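
A worked restatement of the reuse rule with made-up token ids: the cached KV state is kept only when the same model (and adapter) is loaded and the cached tokens form a strict prefix of the new prompt; otherwise the cache is rebuilt and the whole prompt is re-evaluated.

# Hypothetical token ids; the model_key check from above is omitted here.
def reusable_suffix(cached, prompt):
    n = len(cached)
    if n >= len(prompt) or prompt[:n] != cached:
        return None       # stale: rebuild the KV cache and evaluate everything
    return prompt[n:]     # hit: only the new suffix needs to be evaluated

assert reusable_suffix([1, 2, 3, 4], [1, 2, 3, 4, 9, 10]) == [9, 10]
assert reusable_suffix([1, 2, 3, 4], [1, 2, 3]) is None        # prompt shorter than cache
assert reusable_suffix([1, 2, 3, 4], [1, 2, 7, 8, 9]) is None  # prefix mismatch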
def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
@@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
@@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
)
break
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
@@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
def handle_stream(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
@@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
@@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
continue
new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check if there is any remaining text to send
detokenizer.finalize()
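
On the streaming path each non-empty segment is flushed as a Server-Sent Events "data:" line, and the generated tokens are appended to the cache so a follow-up request can extend the same conversation cheaply. A hedged client-side sketch; the endpoint, port, and the [DONE] sentinel are assumptions based on the OpenAI-style schema the server mimics.

import json
from urllib.request import Request, urlopen

# Hypothetical endpoint and port; "stream": True asks the server for SSE output.
body = json.dumps({
    "messages": [{"role": "user", "content": "Write a haiku about caching."}],
    "stream": True,
}).encode()
req = Request(
    "http://localhost:8080/v1/chat/completions",
    data=body,
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    for raw in resp:                      # one SSE line per iteration
        line = raw.decode().strip()
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":           # assumed end-of-stream sentinel
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)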
@@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
@@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
}
return response
def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
@@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
@@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
return prompt
def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
@@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
return self.tokenizer.encode(self.body["prompt"])
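
Both handlers now return plain Python token lists so get_prompt_cache can compare and slice them; the conversion to mx.array happens only at the generate_step call. For completeness, a hypothetical text-completion request against this handler (endpoint, port, and field names are assumptions based on the OpenAI-style schema):

import json
from urllib.request import Request, urlopen

# Hypothetical endpoint and port.
body = json.dumps({"prompt": "def fib(n):", "max_tokens": 64}).encode()
req = Request(
    "http://localhost:8080/v1/completions",
    data=body,
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])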
def do_GET(self):
"""
@@ -669,9 +727,16 @@ def run(
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "