Merge branch 'ml-explore:main' into adding-support-for-mamba2

Gökdeniz Gülmez
2024-10-16 18:57:55 +02:00
committed by GitHub
18 changed files with 756 additions and 428 deletions

View File

@@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. Thesse are
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in.
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
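For illustration, here is a minimal client-side sketch that reads the fields above. It uses only the Python standard library and assumes a server already running on `localhost:8080`, as in the curl examples; the message content and token budget are placeholders.

```python
import json
import urllib.request

# Hypothetical request against a locally running mlx_lm.server instance.
body = json.dumps(
    {
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 32,
    }
).encode()
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=body,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

choice = result["choices"][0]
print(result["model"], result["system_fingerprint"], result["created"])
print(choice["finish_reason"])
print(choice["message"])
print(result["usage"]["total_tokens"])
```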
### List Models
@@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.
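A matching sketch for this endpoint, assuming the list of models is returned under a top-level `data` key (as in the OpenAI schema):

```python
import json
import urllib.request

# Hypothetical query against a locally running mlx_lm.server instance.
with urllib.request.urlopen("http://localhost:8080/v1/models") as resp:
    models = json.load(resp)

for entry in models["data"]:
    print(entry["id"], entry["created"])
```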

View File

@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
Returns:
(int): The number of tokens that were trimmed.
"""
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
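A short usage sketch for the helpers above. It assumes the new functions live next to `make_prompt_cache` in `mlx_lm.models.cache` (as this hunk suggests), that `mlx_lm.load` returns a `(model, tokenizer)` pair, and that the model accepts the cache through its `cache` keyword; the model repo is the example one from the server README.

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    make_prompt_cache,
    trim_prompt_cache,
)

# Example model; downloaded on first use.
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
cache = make_prompt_cache(model)

# Fill the cache by running the model over a short prompt.
tokens = mx.array(tokenizer.encode("The quick brown fox"))[None]
mx.eval(model(tokens, cache=cache))

if can_trim_prompt_cache(cache):
    trimmed = trim_prompt_cache(cache, 2)  # roll the cache back by two tokens
    print("trimmed", trimmed, "tokens")
```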

View File

@@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
@@ -291,7 +291,7 @@ class MoEGate(nn.Module):
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor
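Two small things happen in this file: the rotary key, which carries a single head here, is now broadcast with `mx.repeat` inside each branch instead of a list concatenation before the `if`, and the `mx.stop_gradient` around the router's `argpartition` is dropped. For a singleton axis the repeat and the concatenation are interchangeable; a quick check using only `mlx.core`:

```python
import mlx.core as mx

# A k_pe-like tensor: (batch, 1 head, seq_len, rope_dim).
k_pe = mx.random.normal((2, 1, 5, 16))
num_heads = 4

tiled = mx.concatenate([k_pe] * num_heads, axis=1)  # old formulation
repeated = mx.repeat(k_pe, num_heads, axis=1)       # new formulation

print(mx.array_equal(tiled, repeated))  # True when the repeated axis has size 1
```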

View File

@@ -3,19 +3,38 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
@@ -156,12 +182,21 @@ class ModelProvider:
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
@@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
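The behavioural change in this hunk is the token budget: the OpenAI-style `max_completion_tokens` now takes precedence, `max_tokens` remains as a fallback, and the default rises from 100 to 512. A tiny sketch of the precedence, with a plain dict standing in for `self.body`:

```python
# Hypothetical request body; only the two budget fields matter here.
body = {
    "messages": [{"role": "user", "content": "Hello"}],
    "max_completion_tokens": 64,  # preferred name, wins when present
    "max_tokens": 100,            # legacy name, used only as a fallback
}

max_tokens = body.get("max_completion_tokens", None)
if max_tokens is None:
    max_tokens = body.get("max_tokens", 512)  # new default when neither is set
print(max_tokens)  # 64
```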
@@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
@@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
@@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
)
break
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
@@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
def handle_stream(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
@@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (mx.array): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
@@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
continue
new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check is there any remaining text to send
detokenizer.finalize()
@@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
@@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
}
return response
def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
@@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
@@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
return prompt
def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
@@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
return self.tokenizer.encode(self.body["prompt"])
def do_GET(self):
"""
@@ -669,9 +727,16 @@ def run(
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "

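Stepping back from the hunks above: the heart of the server change is `get_prompt_cache`, which reuses the key-value cache whenever the previous request's tokens form a prefix of the new prompt and rebuilds it otherwise. Below is a standalone rendition of that decision rule, with plain Python lists standing in for the real MLX cache; the names here are illustrative, not the server's own.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple


@dataclass
class ToyPromptCache:
    # Mirrors the PromptCache dataclass added above, minus the MLX cache objects.
    cache: List[Any] = field(default_factory=list)
    model_key: Tuple[str, Optional[str]] = ("", None)
    tokens: List[int] = field(default_factory=list)


def prompt_suffix(pc: ToyPromptCache, model_key, prompt: List[int]) -> List[int]:
    """Return the part of `prompt` that still has to be fed to the model."""
    cache_len = len(pc.tokens)
    if (
        pc.model_key != model_key
        or cache_len >= len(prompt)
        or pc.tokens != prompt[:cache_len]
    ):
        # Different model or no usable prefix: rebuild the cache from scratch.
        pc.model_key = model_key
        pc.cache = []  # in the server this is make_prompt_cache(model)
    else:
        # Cached tokens are a strict prefix of the new prompt: feed only the rest.
        prompt = prompt[cache_len:]
    pc.tokens.extend(prompt)
    return prompt


pc = ToyPromptCache()
print(prompt_suffix(pc, ("repo", None), [1, 2, 3, 4]))        # [1, 2, 3, 4]  (cold cache)
print(prompt_suffix(pc, ("repo", None), [1, 2, 3, 4, 5, 6]))  # [5, 6]  (prefix reused)
```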
View File

@@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
@@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
"""
_byte_decoder = None
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer, trim_space=False):
self.trim_space = trim_space
def __init__(self, tokenizer):
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
@@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
self.text = ""
self.tokens = []
def _maybe_trim_space(self, current_text):
if current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token):
v = self.tokenmap[token]
# if the token starts with space
if self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = v
else:
self._unflushed += v
@@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
)
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = ""
@classmethod
@@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}):
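To make the new space handling concrete, here is a tiny standalone rendition of the `_maybe_trim_space` rule above applied to toy segments; it is the decision rule only, not the detokenizer class itself:

```python
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")


def maybe_trim_space(text_so_far: str, current: str, clean_spaces: bool) -> str:
    # Mirrors BPEStreamingDetokenizer._maybe_trim_space from the hunk above.
    if current[0] != " ":
        return current                # no leading space, nothing to trim
    if not text_so_far:
        return current[1:]            # drop the space at the very start of the text
    if clean_spaces and current[1:].startswith(_space_matches):
        return current[1:]            # drop the space before punctuation/contractions
    return current


print(repr(maybe_trim_space("", " Hello", clean_spaces=True)))       # 'Hello'
print(repr(maybe_trim_space("Hello", " world", clean_spaces=True)))  # ' world'
print(repr(maybe_trim_space("Hello", " 's", clean_spaces=True)))     # "'s"
```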

View File

@@ -246,10 +246,10 @@ def generate_step(
y, logprobs = _step(y)
mx.async_eval(y)
mx.async_eval(y, logprobs)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
@@ -348,7 +348,9 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
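The last two hunks are about scheduling rather than decoding logic: `generate_step` now enqueues both the next token and its log-probabilities with `mx.async_eval` before yielding, and `generate` computes the probability handed to the formatter under `mx.stream(mx.cpu)`, keeping that tiny scalar computation off the default GPU stream. A minimal sketch of the same two idioms outside the generation loop:

```python
import mlx.core as mx

a = mx.random.normal((1024, 1024))

# 1) Start evaluating several arrays in the background, use them later.
b = a @ a
c = mx.softmax(a, axis=-1)
mx.async_eval(b, c)            # both graphs are scheduled without blocking
# ... other Python-side work could happen here ...
print(b[0, 0].item(), c[0, 0].item())

# 2) Run a small follow-up computation on the CPU stream.
with mx.stream(mx.cpu):
    prob = mx.exp(c[0, 0]).item()
print(prob)
```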