mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Prompt caching in mlx_lm.server
(#1026)
* caching in server * nits * fix tests * don't throw if no metal * comments
This commit is contained in:
parent
8dca1a2f60
commit
605c4854f1
@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
|
||||
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
||||
the generated prompt. If not provided, the default mappings are used.
|
||||
|
||||
- `stop`: (Optional) An array of strings or a single string. Thesse are
|
||||
- `stop`: (Optional) An array of strings or a single string. These are
|
||||
sequences of tokens on which the generation should stop.
|
||||
|
||||
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
||||
@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
|
||||
started in.
|
||||
|
||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||
rlative to the directory the server was started in.
|
||||
relative to the directory the server was started in.
|
||||
|
||||
### Response Fields
|
||||
|
||||
- `id`: A unique identifier for the chat.
|
||||
|
||||
- `system_fingerprint`: A unique identifier for the system.
|
||||
|
||||
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
|
||||
streaming), or "text.completion".
|
||||
|
||||
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
|
||||
|
||||
- `created`: A time-stamp for when the request was processed.
|
||||
|
||||
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
|
||||
- `index`: The index in the list.
|
||||
- `logprobs`: A dictionary containing the fields:
|
||||
- `token_logprobs`: A list of the log probabilities for the generated
|
||||
tokens.
|
||||
- `tokens`: A list of the generated token ids.
|
||||
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
|
||||
top tokens (if requested) with their corresponding probabilities.
|
||||
- `finish_reason`: The reason the completion ended. This can be either of
|
||||
`"stop"` or `"length"`.
|
||||
- `message`: The text response from the model.
|
||||
|
||||
- `usage`: A dictionary containing the fields:
|
||||
- `prompt_tokens`: The number of prompt tokens processed.
|
||||
- `completion_tokens`: The number of tokens generated.
|
||||
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
|
||||
|
||||
### List Models
|
||||
|
||||
@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
|
||||
This will return a list of locally available models where each model in the
|
||||
list contains the following fields:
|
||||
|
||||
- `"id"`: The Hugging Face repo id.
|
||||
- `"created"`: A timestamp representing the model creation time.
|
||||
- `id`: The Hugging Face repo id.
|
||||
- `created`: A time-stamp representing the model creation time.
|
||||
|
@ -3,19 +3,38 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
from ._version import __version__
|
||||
from .models.cache import make_prompt_cache
|
||||
from .utils import generate_step, load
|
||||
|
||||
|
||||
def get_system_fingerprint():
|
||||
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
|
||||
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
|
||||
|
||||
|
||||
class StopCondition(NamedTuple):
|
||||
stop_met: bool
|
||||
trim_length: int
|
||||
@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptCache:
|
||||
cache: List[Any] = field(default_factory=list)
|
||||
model_key: Tuple[str, Optional[str]] = ("", None)
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelProvider:
|
||||
def __init__(self, cli_args: argparse.Namespace):
|
||||
"""Load models on demand and persist them across the whole process."""
|
||||
@ -156,12 +182,21 @@ class ModelProvider:
|
||||
|
||||
|
||||
class APIHandler(BaseHTTPRequestHandler):
|
||||
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: ModelProvider,
|
||||
*args,
|
||||
prompt_cache: Optional[PromptCache] = None,
|
||||
system_fingerprint: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create static request specific metadata
|
||||
"""
|
||||
self.created = int(time.time())
|
||||
self.model_provider = model_provider
|
||||
self.prompt_cache = prompt_cache or PromptCache()
|
||||
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _set_cors_headers(self):
|
||||
@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.stream_options = self.body.get("stream_options", None)
|
||||
self.requested_model = self.body.get("model", "default_model")
|
||||
self.adapter = self.body.get("adapters", None)
|
||||
self.max_tokens = self.body.get("max_tokens", 100)
|
||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||
if self.max_tokens is None:
|
||||
self.max_tokens = self.body.get("max_tokens", 512)
|
||||
self.temperature = self.body.get("temperature", 1.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||
@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
# Static response
|
||||
response = {
|
||||
"id": self.request_id,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
"system_fingerprint": self.system_fingerprint,
|
||||
"object": self.object_type,
|
||||
"model": self.requested_model,
|
||||
"created": self.created,
|
||||
@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
return response
|
||||
|
||||
def get_prompt_cache(self, prompt):
|
||||
cache_len = len(self.prompt_cache.tokens)
|
||||
if (
|
||||
self.prompt_cache.model_key != self.model_provider.model_key
|
||||
or cache_len >= len(prompt)
|
||||
or self.prompt_cache.tokens != prompt[:cache_len]
|
||||
):
|
||||
self.prompt_cache.model_key = self.model_provider.model_key
|
||||
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
|
||||
else:
|
||||
prompt = prompt[cache_len:]
|
||||
self.prompt_cache.tokens.extend(prompt)
|
||||
return prompt
|
||||
|
||||
def handle_completion(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
prompt: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Generate a response to a prompt and send it to the client in a single batch.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||
prompt (List[int]): The tokenized prompt.
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||
to the stopping_criteria function
|
||||
"""
|
||||
@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
logging.debug(f"Starting completion:")
|
||||
token_logprobs = []
|
||||
top_tokens = []
|
||||
for (token, logprobs), _ in zip(
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
for _, (token, logprobs) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=prompt,
|
||||
prompt=mx.array(prompt),
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
logit_bias=self.logit_bias,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
range(self.max_tokens),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
top_indices = sorted_indices[: self.logprobs]
|
||||
top_logprobs = logprobs[top_indices]
|
||||
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
|
||||
top_tokens.append(dict(top_token_info))
|
||||
top_tokens.append(tuple(top_token_info))
|
||||
|
||||
token_logprobs.append(logprobs[token].item())
|
||||
|
||||
@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
)
|
||||
break
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
detokenizer.finalize()
|
||||
text = (
|
||||
detokenizer.text
|
||||
@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def handle_stream(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
prompt: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
Sent Events (SSE) stream.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||
prompt (mx.array): The tokenized prompt
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
||||
the stopping_criteria function
|
||||
"""
|
||||
@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_sequence_suffix = None
|
||||
logging.debug(f"Starting stream:")
|
||||
|
||||
for (token, _), _ in zip(
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
for _, (token, _) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=prompt,
|
||||
prompt=mx.array(prompt),
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
range(self.max_tokens),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
continue
|
||||
|
||||
new_text = detokenizer.last_segment
|
||||
response = self.generate_response(new_text, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
if new_text:
|
||||
response = self.generate_response(new_text, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
|
||||
# check is there any remaining text to send
|
||||
detokenizer.finalize()
|
||||
@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
):
|
||||
response = {
|
||||
"id": self.request_id,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
"system_fingerprint": self.system_fingerprint,
|
||||
"object": "chat.completion",
|
||||
"model": self.requested_model,
|
||||
"created": self.created,
|
||||
@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
}
|
||||
return response
|
||||
|
||||
def handle_chat_completions(self) -> mx.array:
|
||||
def handle_chat_completions(self) -> List[int]:
|
||||
"""
|
||||
Handle a chat completion request.
|
||||
|
||||
@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.object_type = (
|
||||
"chat.completions.chunk" if self.stream else "chat.completions"
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(self.tokenizer, "apply_chat_template")
|
||||
and self.tokenizer.chat_template
|
||||
@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||
prompt = self.tokenizer.encode(prompt)
|
||||
|
||||
return mx.array(prompt)
|
||||
return prompt
|
||||
|
||||
def handle_text_completions(self) -> mx.array:
|
||||
def handle_text_completions(self) -> List[int]:
|
||||
"""
|
||||
Handle a text completion request.
|
||||
|
||||
@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
# Determine response type
|
||||
self.request_id = f"cmpl-{uuid.uuid4()}"
|
||||
self.object_type = "text_completion"
|
||||
|
||||
assert "prompt" in self.body, "Request did not contain a prompt"
|
||||
prompt_text = self.body["prompt"]
|
||||
prompt = self.tokenizer.encode(prompt_text)
|
||||
return mx.array(prompt)
|
||||
return self.tokenizer.encode(self.body["prompt"])
|
||||
|
||||
def do_GET(self):
|
||||
"""
|
||||
@ -669,9 +727,16 @@ def run(
|
||||
handler_class=APIHandler,
|
||||
):
|
||||
server_address = (host, port)
|
||||
prompt_cache = PromptCache()
|
||||
httpd = server_class(
|
||||
server_address,
|
||||
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
|
||||
lambda *args, **kwargs: handler_class(
|
||||
model_provider,
|
||||
prompt_cache=prompt_cache,
|
||||
system_fingerprint=get_system_fingerprint(),
|
||||
*args,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
warnings.warn(
|
||||
"mlx_lm.server is not recommended for production as "
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
|
||||
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
|
||||
)
|
||||
|
||||
def test_cache_copying(self):
|
||||
cache = [KVCache()]
|
||||
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
cache[0].update_and_fetch(x, x)
|
||||
|
||||
y = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
cache[0].update_and_fetch(y, y)
|
||||
|
||||
old_cache = copy.deepcopy(cache)
|
||||
|
||||
trim_prompt_cache(cache, 1)
|
||||
|
||||
self.assertTrue(old_cache[0].offset, 11)
|
||||
self.assertTrue(cache[0].offset, 10)
|
||||
|
||||
z = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
cache[0].update_and_fetch(z, z)
|
||||
|
||||
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
||||
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -14,6 +14,7 @@ class DummyModelProvider:
|
||||
def __init__(self):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||
self.model_key = (HF_MODEL_PATH, None)
|
||||
|
||||
def load(self, model, adapter=None):
|
||||
assert model in ["default_model", "chat_model"]
|
||||
|
Loading…
Reference in New Issue
Block a user