Prompt caching in mlx_lm.server (#1026)

* caching in server

* nits

* fix tests

* don't throw if no metal

* comments
Awni Hannun 2024-10-14 10:57:22 -07:00 committed by GitHub
parent 8dca1a2f60
commit 605c4854f1
4 changed files with 151 additions and 32 deletions

@@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
 - `role_mapping`: (Optional) A dictionary to customize the role prefixes in
   the generated prompt. If not provided, the default mappings are used.
-- `stop`: (Optional) An array of strings or a single string. Thesse are
+- `stop`: (Optional) An array of strings or a single string. These are
   sequences of tokens on which the generation should stop.
 - `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
   started in.
 - `adapters`: (Optional) A string path to low-rank adapters. The path must be
-  rlative to the directory the server was started in.
+  relative to the directory the server was started in.
+
+### Response Fields
+
+- `id`: A unique identifier for the chat.
+- `system_fingerprint`: A unique identifier for the system.
+- `object`: Any of "chat.completions", "chat.completions.chunk" (for
+  streaming), or "text.completion".
+- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
+- `created`: A time-stamp for when the request was processed.
+- `choices`: A list of outputs. Each output is a dictionary containing the fields:
+  - `index`: The index in the list.
+  - `logprobs`: A dictionary containing the fields:
+    - `token_logprobs`: A list of the log probabilities for the generated
+      tokens.
+    - `tokens`: A list of the generated token ids.
+    - `top_logprobs`: A list of lists. Each list contains the `logprobs`
+      top tokens (if requested) with their corresponding probabilities.
+  - `finish_reason`: The reason the completion ended. This can be either of
+    `"stop"` or `"length"`.
+  - `message`: The text response from the model.
+- `usage`: A dictionary containing the fields:
+  - `prompt_tokens`: The number of prompt tokens processed.
+  - `completion_tokens`: The number of tokens generated.
+  - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
+
 ### List Models
@@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
 This will return a list of locally available models where each model in the
 list contains the following fields:
-- `"id"`: The Hugging Face repo id.
-- `"created"`: A timestamp representing the model creation time.
+- `id`: The Hugging Face repo id.
+- `created`: A time-stamp representing the model creation time.

@@ -3,19 +3,38 @@
 import argparse
 import json
 import logging
+import platform
 import time
 import uuid
 import warnings
+from dataclasses import dataclass, field
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import mlx.core as mx
 from huggingface_hub import scan_cache_dir
 
+from ._version import __version__
+from .models.cache import make_prompt_cache
 from .utils import generate_step, load
 
 
+def get_system_fingerprint():
+    gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
+    return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
+
+
 class StopCondition(NamedTuple):
     stop_met: bool
     trim_length: int
@@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+@dataclass
+class PromptCache:
+    cache: List[Any] = field(default_factory=list)
+    model_key: Tuple[str, Optional[str]] = ("", None)
+    tokens: List[int] = field(default_factory=list)
+
+
 class ModelProvider:
     def __init__(self, cli_args: argparse.Namespace):
         """Load models on demand and persist them across the whole process."""
@@ -156,12 +182,21 @@ class ModelProvider:
 
 
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
+    def __init__(
+        self,
+        model_provider: ModelProvider,
+        *args,
+        prompt_cache: Optional[PromptCache] = None,
+        system_fingerprint: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Create static request specific metadata
         """
         self.created = int(time.time())
         self.model_provider = model_provider
+        self.prompt_cache = prompt_cache or PromptCache()
+        self.system_fingerprint = system_fingerprint or get_system_fingerprint()
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
@@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
         self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
         self.adapter = self.body.get("adapters", None)
-        self.max_tokens = self.body.get("max_tokens", 100)
+        self.max_tokens = self.body.get("max_completion_tokens", None)
+        if self.max_tokens is None:
+            self.max_tokens = self.body.get("max_tokens", 512)
         self.temperature = self.body.get("temperature", 1.0)
         self.top_p = self.body.get("top_p", 1.0)
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "system_fingerprint": self.system_fingerprint,
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
 
         return response
 
+    def get_prompt_cache(self, prompt):
+        cache_len = len(self.prompt_cache.tokens)
+        if (
+            self.prompt_cache.model_key != self.model_provider.model_key
+            or cache_len >= len(prompt)
+            or self.prompt_cache.tokens != prompt[:cache_len]
+        ):
+            self.prompt_cache.model_key = self.model_provider.model_key
+            self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
+        else:
+            prompt = prompt[cache_len:]
+        self.prompt_cache.tokens.extend(prompt)
+        return prompt
+
     def handle_completion(
         self,
-        prompt: mx.array,
+        prompt: List[int],
         stop_id_sequences: List[List[int]],
     ):
         """
         Generate a response to a prompt and send it to the client in a single batch.
 
         Args:
-            prompt (mx.array): The prompt, in token form inside of a mlx array
+            prompt (List[int]): The tokenized prompt.
             stop_id_sequences (List[List[int]]): A list of stop words passed
                 to the stopping_criteria function
         """
@@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
         logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
-        for (token, logprobs), _ in zip(
+
+        prompt = self.get_prompt_cache(prompt)
+
+        for _, (token, logprobs) in zip(
+            range(self.max_tokens),
             generate_step(
-                prompt=prompt,
+                prompt=mx.array(prompt),
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
+                prompt_cache=self.prompt_cache.cache,
             ),
-            range(self.max_tokens),
         ):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
@@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 top_indices = sorted_indices[: self.logprobs]
                 top_logprobs = logprobs[top_indices]
                 top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
-                top_tokens.append(dict(top_token_info))
+                top_tokens.append(tuple(top_token_info))
 
             token_logprobs.append(logprobs[token].item())
@@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
                     )
                 break
 
+        self.prompt_cache.tokens.extend(tokens)
         detokenizer.finalize()
         text = (
             detokenizer.text
@@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
     def handle_stream(
         self,
-        prompt: mx.array,
+        prompt: List[int],
         stop_id_sequences: List[List[int]],
     ):
         """
@@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
         Sent Events (SSE) stream.
 
         Args:
-            prompt (mx.array): The prompt, in token form inside of a mlx array
+            prompt (mx.array): The tokenized prompt
             stop_id_sequences (List[List[int]]): A list of stop words passed to
                 the stopping_criteria function
         """
@@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_sequence_suffix = None
         logging.debug(f"Starting stream:")
 
-        for (token, _), _ in zip(
+        prompt = self.get_prompt_cache(prompt)
+
+        for _, (token, _) in zip(
+            range(self.max_tokens),
             generate_step(
-                prompt=prompt,
+                prompt=mx.array(prompt),
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
+                prompt_cache=self.prompt_cache.cache,
             ),
-            range(self.max_tokens),
         ):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
@@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
                 continue
 
             new_text = detokenizer.last_segment
-            response = self.generate_response(new_text, None)
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-            self.wfile.flush()
+            if new_text:
+                response = self.generate_response(new_text, None)
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+
+        self.prompt_cache.tokens.extend(tokens)
 
         # check is there any remaining text to send
         detokenizer.finalize()
@@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
     ):
         response = {
             "id": self.request_id,
-            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "system_fingerprint": self.system_fingerprint,
             "object": "chat.completion",
             "model": self.requested_model,
             "created": self.created,
@@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
         }
         return response
 
-    def handle_chat_completions(self) -> mx.array:
+    def handle_chat_completions(self) -> List[int]:
         """
         Handle a chat completion request.
@@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
         self.object_type = (
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
-
         if (
             hasattr(self.tokenizer, "apply_chat_template")
             and self.tokenizer.chat_template
@@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
             prompt = self.tokenizer.encode(prompt)
 
-        return mx.array(prompt)
+        return prompt
 
-    def handle_text_completions(self) -> mx.array:
+    def handle_text_completions(self) -> List[int]:
         """
         Handle a text completion request.
@@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
         # Determine response type
         self.request_id = f"cmpl-{uuid.uuid4()}"
         self.object_type = "text_completion"
-
         assert "prompt" in self.body, "Request did not contain a prompt"
-        prompt_text = self.body["prompt"]
-        prompt = self.tokenizer.encode(prompt_text)
-        return mx.array(prompt)
+        return self.tokenizer.encode(self.body["prompt"])
 
     def do_GET(self):
         """
@@ -669,9 +727,16 @@ def run(
     handler_class=APIHandler,
 ):
     server_address = (host, port)
+    prompt_cache = PromptCache()
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(
+            model_provider,
+            prompt_cache=prompt_cache,
+            system_fingerprint=get_system_fingerprint(),
+            *args,
+            **kwargs,
+        ),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "

@@ -1,5 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
+import copy
 import os
 import tempfile
 import unittest
@@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
             all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
         )
 
+    def test_cache_copying(self):
+        cache = [KVCache()]
+
+        x = mx.random.uniform(shape=(1, 8, 10, 4))
+        cache[0].update_and_fetch(x, x)
+
+        y = mx.random.uniform(shape=(1, 8, 1, 4))
+        cache[0].update_and_fetch(y, y)
+
+        old_cache = copy.deepcopy(cache)
+
+        trim_prompt_cache(cache, 1)
+
+        self.assertTrue(old_cache[0].offset, 11)
+        self.assertTrue(cache[0].offset, 10)
+
+        z = mx.random.uniform(shape=(1, 8, 1, 4))
+        cache[0].update_and_fetch(z, z)
+
+        self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
+        self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
+
 
 if __name__ == "__main__":
     unittest.main()

@@ -14,6 +14,7 @@ class DummyModelProvider:
     def __init__(self):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        self.model, self.tokenizer = load(HF_MODEL_PATH)
+        self.model_key = (HF_MODEL_PATH, None)
 
     def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]