mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Prompt caching in mlx_lm.server
(#1026)
* caching in server * nits * fix tests * don't throw if no metal * comments
This commit is contained in:
parent
8dca1a2f60
commit
605c4854f1
@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
|
|||||||
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
||||||
the generated prompt. If not provided, the default mappings are used.
|
the generated prompt. If not provided, the default mappings are used.
|
||||||
|
|
||||||
- `stop`: (Optional) An array of strings or a single string. Thesse are
|
- `stop`: (Optional) An array of strings or a single string. These are
|
||||||
sequences of tokens on which the generation should stop.
|
sequences of tokens on which the generation should stop.
|
||||||
|
|
||||||
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
||||||
@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
|
|||||||
started in.
|
started in.
|
||||||
|
|
||||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||||
rlative to the directory the server was started in.
|
relative to the directory the server was started in.
|
||||||
|
|
||||||
|
### Response Fields
|
||||||
|
|
||||||
|
- `id`: A unique identifier for the chat.
|
||||||
|
|
||||||
|
- `system_fingerprint`: A unique identifier for the system.
|
||||||
|
|
||||||
|
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
|
||||||
|
streaming), or "text.completion".
|
||||||
|
|
||||||
|
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
|
||||||
|
|
||||||
|
- `created`: A time-stamp for when the request was processed.
|
||||||
|
|
||||||
|
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
|
||||||
|
- `index`: The index in the list.
|
||||||
|
- `logprobs`: A dictionary containing the fields:
|
||||||
|
- `token_logprobs`: A list of the log probabilities for the generated
|
||||||
|
tokens.
|
||||||
|
- `tokens`: A list of the generated token ids.
|
||||||
|
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
|
||||||
|
top tokens (if requested) with their corresponding probabilities.
|
||||||
|
- `finish_reason`: The reason the completion ended. This can be either of
|
||||||
|
`"stop"` or `"length"`.
|
||||||
|
- `message`: The text response from the model.
|
||||||
|
|
||||||
|
- `usage`: A dictionary containing the fields:
|
||||||
|
- `prompt_tokens`: The number of prompt tokens processed.
|
||||||
|
- `completion_tokens`: The number of tokens generated.
|
||||||
|
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
|
||||||
|
|
||||||
### List Models
|
### List Models
|
||||||
|
|
||||||
@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
|
|||||||
This will return a list of locally available models where each model in the
|
This will return a list of locally available models where each model in the
|
||||||
list contains the following fields:
|
list contains the following fields:
|
||||||
|
|
||||||
- `"id"`: The Hugging Face repo id.
|
- `id`: The Hugging Face repo id.
|
||||||
- `"created"`: A timestamp representing the model creation time.
|
- `created`: A time-stamp representing the model creation time.
|
||||||
|
@ -3,19 +3,38 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import platform
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
|
|
||||||
|
from ._version import __version__
|
||||||
|
from .models.cache import make_prompt_cache
|
||||||
from .utils import generate_step, load
|
from .utils import generate_step, load
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_fingerprint():
|
||||||
|
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
|
||||||
|
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
|
||||||
|
|
||||||
|
|
||||||
class StopCondition(NamedTuple):
|
class StopCondition(NamedTuple):
|
||||||
stop_met: bool
|
stop_met: bool
|
||||||
trim_length: int
|
trim_length: int
|
||||||
@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
|||||||
return prompt.rstrip()
|
return prompt.rstrip()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptCache:
|
||||||
|
cache: List[Any] = field(default_factory=list)
|
||||||
|
model_key: Tuple[str, Optional[str]] = ("", None)
|
||||||
|
tokens: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ModelProvider:
|
class ModelProvider:
|
||||||
def __init__(self, cli_args: argparse.Namespace):
|
def __init__(self, cli_args: argparse.Namespace):
|
||||||
"""Load models on demand and persist them across the whole process."""
|
"""Load models on demand and persist them across the whole process."""
|
||||||
@ -156,12 +182,21 @@ class ModelProvider:
|
|||||||
|
|
||||||
|
|
||||||
class APIHandler(BaseHTTPRequestHandler):
|
class APIHandler(BaseHTTPRequestHandler):
|
||||||
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_provider: ModelProvider,
|
||||||
|
*args,
|
||||||
|
prompt_cache: Optional[PromptCache] = None,
|
||||||
|
system_fingerprint: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create static request specific metadata
|
Create static request specific metadata
|
||||||
"""
|
"""
|
||||||
self.created = int(time.time())
|
self.created = int(time.time())
|
||||||
self.model_provider = model_provider
|
self.model_provider = model_provider
|
||||||
|
self.prompt_cache = prompt_cache or PromptCache()
|
||||||
|
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _set_cors_headers(self):
|
def _set_cors_headers(self):
|
||||||
@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.stream_options = self.body.get("stream_options", None)
|
self.stream_options = self.body.get("stream_options", None)
|
||||||
self.requested_model = self.body.get("model", "default_model")
|
self.requested_model = self.body.get("model", "default_model")
|
||||||
self.adapter = self.body.get("adapters", None)
|
self.adapter = self.body.get("adapters", None)
|
||||||
self.max_tokens = self.body.get("max_tokens", 100)
|
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||||
|
if self.max_tokens is None:
|
||||||
|
self.max_tokens = self.body.get("max_tokens", 512)
|
||||||
self.temperature = self.body.get("temperature", 1.0)
|
self.temperature = self.body.get("temperature", 1.0)
|
||||||
self.top_p = self.body.get("top_p", 1.0)
|
self.top_p = self.body.get("top_p", 1.0)
|
||||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||||
@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
# Static response
|
# Static response
|
||||||
response = {
|
response = {
|
||||||
"id": self.request_id,
|
"id": self.request_id,
|
||||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
"system_fingerprint": self.system_fingerprint,
|
||||||
"object": self.object_type,
|
"object": self.object_type,
|
||||||
"model": self.requested_model,
|
"model": self.requested_model,
|
||||||
"created": self.created,
|
"created": self.created,
|
||||||
@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def get_prompt_cache(self, prompt):
|
||||||
|
cache_len = len(self.prompt_cache.tokens)
|
||||||
|
if (
|
||||||
|
self.prompt_cache.model_key != self.model_provider.model_key
|
||||||
|
or cache_len >= len(prompt)
|
||||||
|
or self.prompt_cache.tokens != prompt[:cache_len]
|
||||||
|
):
|
||||||
|
self.prompt_cache.model_key = self.model_provider.model_key
|
||||||
|
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
|
||||||
|
else:
|
||||||
|
prompt = prompt[cache_len:]
|
||||||
|
self.prompt_cache.tokens.extend(prompt)
|
||||||
|
return prompt
|
||||||
|
|
||||||
def handle_completion(
|
def handle_completion(
|
||||||
self,
|
self,
|
||||||
prompt: mx.array,
|
prompt: List[int],
|
||||||
stop_id_sequences: List[List[int]],
|
stop_id_sequences: List[List[int]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate a response to a prompt and send it to the client in a single batch.
|
Generate a response to a prompt and send it to the client in a single batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
prompt (List[int]): The tokenized prompt.
|
||||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||||
to the stopping_criteria function
|
to the stopping_criteria function
|
||||||
"""
|
"""
|
||||||
@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
logging.debug(f"Starting completion:")
|
logging.debug(f"Starting completion:")
|
||||||
token_logprobs = []
|
token_logprobs = []
|
||||||
top_tokens = []
|
top_tokens = []
|
||||||
for (token, logprobs), _ in zip(
|
|
||||||
|
prompt = self.get_prompt_cache(prompt)
|
||||||
|
|
||||||
|
for _, (token, logprobs) in zip(
|
||||||
|
range(self.max_tokens),
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=prompt,
|
prompt=mx.array(prompt),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
repetition_penalty=self.repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=self.repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
|
prompt_cache=self.prompt_cache.cache,
|
||||||
),
|
),
|
||||||
range(self.max_tokens),
|
|
||||||
):
|
):
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
logging.debug(detokenizer.text)
|
logging.debug(detokenizer.text)
|
||||||
@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
top_indices = sorted_indices[: self.logprobs]
|
top_indices = sorted_indices[: self.logprobs]
|
||||||
top_logprobs = logprobs[top_indices]
|
top_logprobs = logprobs[top_indices]
|
||||||
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
|
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
|
||||||
top_tokens.append(dict(top_token_info))
|
top_tokens.append(tuple(top_token_info))
|
||||||
|
|
||||||
token_logprobs.append(logprobs[token].item())
|
token_logprobs.append(logprobs[token].item())
|
||||||
|
|
||||||
@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
self.prompt_cache.tokens.extend(tokens)
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
text = (
|
text = (
|
||||||
detokenizer.text
|
detokenizer.text
|
||||||
@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def handle_stream(
|
def handle_stream(
|
||||||
self,
|
self,
|
||||||
prompt: mx.array,
|
prompt: List[int],
|
||||||
stop_id_sequences: List[List[int]],
|
stop_id_sequences: List[List[int]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
Sent Events (SSE) stream.
|
Sent Events (SSE) stream.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
prompt (mx.array): The tokenized prompt
|
||||||
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
||||||
the stopping_criteria function
|
the stopping_criteria function
|
||||||
"""
|
"""
|
||||||
@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_sequence_suffix = None
|
stop_sequence_suffix = None
|
||||||
logging.debug(f"Starting stream:")
|
logging.debug(f"Starting stream:")
|
||||||
|
|
||||||
for (token, _), _ in zip(
|
prompt = self.get_prompt_cache(prompt)
|
||||||
|
|
||||||
|
for _, (token, _) in zip(
|
||||||
|
range(self.max_tokens),
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=prompt,
|
prompt=mx.array(prompt),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
repetition_penalty=self.repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=self.repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
|
prompt_cache=self.prompt_cache.cache,
|
||||||
),
|
),
|
||||||
range(self.max_tokens),
|
|
||||||
):
|
):
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
logging.debug(detokenizer.text)
|
logging.debug(detokenizer.text)
|
||||||
@ -531,10 +590,13 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
new_text = detokenizer.last_segment
|
new_text = detokenizer.last_segment
|
||||||
|
if new_text:
|
||||||
response = self.generate_response(new_text, None)
|
response = self.generate_response(new_text, None)
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
|
self.prompt_cache.tokens.extend(tokens)
|
||||||
|
|
||||||
# check is there any remaining text to send
|
# check is there any remaining text to send
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
last_segment = detokenizer.last_segment
|
last_segment = detokenizer.last_segment
|
||||||
@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
):
|
):
|
||||||
response = {
|
response = {
|
||||||
"id": self.request_id,
|
"id": self.request_id,
|
||||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
"system_fingerprint": self.system_fingerprint,
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"model": self.requested_model,
|
"model": self.requested_model,
|
||||||
"created": self.created,
|
"created": self.created,
|
||||||
@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
}
|
}
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def handle_chat_completions(self) -> mx.array:
|
def handle_chat_completions(self) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Handle a chat completion request.
|
Handle a chat completion request.
|
||||||
|
|
||||||
@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.object_type = (
|
self.object_type = (
|
||||||
"chat.completions.chunk" if self.stream else "chat.completions"
|
"chat.completions.chunk" if self.stream else "chat.completions"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.tokenizer, "apply_chat_template")
|
hasattr(self.tokenizer, "apply_chat_template")
|
||||||
and self.tokenizer.chat_template
|
and self.tokenizer.chat_template
|
||||||
@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||||
prompt = self.tokenizer.encode(prompt)
|
prompt = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
return mx.array(prompt)
|
return prompt
|
||||||
|
|
||||||
def handle_text_completions(self) -> mx.array:
|
def handle_text_completions(self) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Handle a text completion request.
|
Handle a text completion request.
|
||||||
|
|
||||||
@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
# Determine response type
|
# Determine response type
|
||||||
self.request_id = f"cmpl-{uuid.uuid4()}"
|
self.request_id = f"cmpl-{uuid.uuid4()}"
|
||||||
self.object_type = "text_completion"
|
self.object_type = "text_completion"
|
||||||
|
|
||||||
assert "prompt" in self.body, "Request did not contain a prompt"
|
assert "prompt" in self.body, "Request did not contain a prompt"
|
||||||
prompt_text = self.body["prompt"]
|
return self.tokenizer.encode(self.body["prompt"])
|
||||||
prompt = self.tokenizer.encode(prompt_text)
|
|
||||||
return mx.array(prompt)
|
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
"""
|
"""
|
||||||
@ -669,9 +727,16 @@ def run(
|
|||||||
handler_class=APIHandler,
|
handler_class=APIHandler,
|
||||||
):
|
):
|
||||||
server_address = (host, port)
|
server_address = (host, port)
|
||||||
|
prompt_cache = PromptCache()
|
||||||
httpd = server_class(
|
httpd = server_class(
|
||||||
server_address,
|
server_address,
|
||||||
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
|
lambda *args, **kwargs: handler_class(
|
||||||
|
model_provider,
|
||||||
|
prompt_cache=prompt_cache,
|
||||||
|
system_fingerprint=get_system_fingerprint(),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"mlx_lm.server is not recommended for production as "
|
"mlx_lm.server is not recommended for production as "
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
|
|||||||
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
|
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_cache_copying(self):
|
||||||
|
cache = [KVCache()]
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||||
|
cache[0].update_and_fetch(x, x)
|
||||||
|
|
||||||
|
y = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||||
|
cache[0].update_and_fetch(y, y)
|
||||||
|
|
||||||
|
old_cache = copy.deepcopy(cache)
|
||||||
|
|
||||||
|
trim_prompt_cache(cache, 1)
|
||||||
|
|
||||||
|
self.assertTrue(old_cache[0].offset, 11)
|
||||||
|
self.assertTrue(cache[0].offset, 10)
|
||||||
|
|
||||||
|
z = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||||
|
cache[0].update_and_fetch(z, z)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
||||||
|
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -14,6 +14,7 @@ class DummyModelProvider:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||||
|
self.model_key = (HF_MODEL_PATH, None)
|
||||||
|
|
||||||
def load(self, model, adapter=None):
|
def load(self, model, adapter=None):
|
||||||
assert model in ["default_model", "chat_model"]
|
assert model in ["default_model", "chat_model"]
|
||||||
|
Loading…
Reference in New Issue
Block a user