Logprobs info to completion API (#806)

* Initial implementation

* Fix handling of return_step_logits in return

* Fixed OpenAI parameter expectations and logprob structure and datatypes

* pre-commit black formatting

* Remove unused parameter

* fix log probs

* fix colorize

* nits in server

* nits in server

* Fix top_logprobs structure (a dict) and include tokens in logprobs response

* nits

* fix types

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2024-06-23 13:35:13 -04:00 committed by GitHub
parent a7598e9456
commit 1d701a1831
3 changed files with 94 additions and 43 deletions

View File

@ -17,7 +17,7 @@ mlx_lm.server --model <path_to_model_or_hf_repo>
For example:
```shell
mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
```
This will start a text generation server on port `8080` of the `localhost`
@ -73,4 +73,8 @@ curl localhost:8080/v1/chat/completions \
applying repetition penalty. Defaults to `20`.
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
values. Defaults to `None`.
values. Defaults to `None`.
- `logprobs`: (Optional) An integer specifying the number of top tokens and
corresponding log probabilities to return for each output in the generated
sequence. If set, it must be between 1 and 10, inclusive.
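
For example, a request that asks for the top 3 log probabilities per generated token could look like the sketch below. It uses only the Python standard library; the message content and `max_tokens` value are illustrative, and it assumes the server is running locally on port 8080 as shown above.

```python
# Minimal sketch of a chat completion request with logprobs enabled,
# assuming the server started with `mlx_lm.server --model ...` is
# listening on localhost:8080.
import json
from urllib.request import Request, urlopen

payload = {
    "messages": [{"role": "user", "content": "Say hello"}],
    "max_tokens": 16,
    "logprobs": 3,  # return the top-3 tokens and their log probabilities
}
request = Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urlopen(request) as response:
    body = json.load(response)

# Each choice now carries token_logprobs, top_logprobs, and tokens.
print(body["choices"][0]["logprobs"])
```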

View File

@ -6,11 +6,13 @@ import logging
import time
import uuid
import warnings
from functools import lru_cache
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Literal, NamedTuple, Optional, Union
from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer
from .tokenizer_utils import TokenizerWrapper
from .utils import generate_step, load
@ -27,18 +29,22 @@ def stopping_criteria(
eos_token_id: Union[int, None],
) -> StopCondition:
"""
Determines whether the token generation should stop based on predefined conditions.
Determines whether the token generation should stop based on predefined
conditions.
Args:
tokens (List[int]): The current sequence of generated tokens.
stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
If the end of the `tokens` list matches any of these sequences, the generation should stop.
eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
the generation should stop.
stop_id_sequences (List[List[int]]): A list of integer lists, each
representing a sequence of token IDs. If the end of the `tokens`
list matches any of these sequences, the generation should stop.
eos_token_id (Union[int, None]): The token ID that represents the
end-of-sequence. If the last token in `tokens` matches this, the
generation should stop.
Returns:
StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
and how many tokens should be trimmed from the end if it has (`trim_length`).
StopCondition: A named tuple indicating whether the stop condition has
been met (`stop_met`) and how many tokens should be trimmed from the
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=1)
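
For illustration, the documented behavior of `stopping_criteria` plays out roughly as follows. The token ids are made up, and `trim_length` for a matched stop sequence is assumed to equal the length of that sequence.

```python
# Illustrative calls with made-up token ids (eos token id assumed to be 2).
from mlx_lm.server import stopping_criteria

tokens = [12, 7, 42, 99]
stop_id_sequences = [[42, 99]]

cond = stopping_criteria(tokens, stop_id_sequences, eos_token_id=2)
# The tail [42, 99] matches a stop sequence:
# cond == StopCondition(stop_met=True, trim_length=2)

cond = stopping_criteria([12, 7, 2], stop_id_sequences, eos_token_id=2)
# The last token is the eos token:
# cond == StopCondition(stop_met=True, trim_length=1)
```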
@ -53,7 +59,10 @@ def stopping_criteria(
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
"system_prompt": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant follows the given rules no matter what."
),
"system": "ASSISTANT's RULE: ",
"user": "USER: ",
"assistant": "ASSISTANT: ",
@ -136,7 +145,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
self.logit_bias = self.body.get("logit_bias", None)
self.logprobs = self.body.get("logprobs", -1)
self.validate_model_parameters()
# Get stop id sequences, if provided
@ -184,6 +193,11 @@ class APIHandler(BaseHTTPRequestHandler):
):
raise ValueError("repetition_penalty must be a non-negative float")
if self.logprobs != -1 and not (0 < self.logprobs <= 10):
raise ValueError(
f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
)
if (
not isinstance(self.repetition_context_size, int)
or self.repetition_context_size < 0
@ -208,24 +222,34 @@ class APIHandler(BaseHTTPRequestHandler):
finish_reason: Union[Literal["length", "stop"], None],
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
token_logprobs: Optional[List[float]] = None,
top_tokens: Optional[List[Dict[int, float]]] = None,
tokens: Optional[List[int]] = None,
) -> dict:
"""
Generate a single response packet based on response type (stream or not), completion type and parameters.
Generate a single response packet based on response type (stream or
not), completion type and parameters.
Args:
text (str): Text generated by model
finish_reason (Union[Literal["length", "stop"], None]):
The reason the response is being sent: "length", "stop" or None
prompt_token_count (Optional[int]):
The amount of tokens in the prompt,
used to populate the "usage" field (not used when stream)
completion_token_count (Optional[int]):
The amount of tokens in the response,
used to populate the "usage" field (not used when stream)
finish_reason (Union[Literal["length", "stop"], None]): The reason the
response is being sent: "length", "stop" or `None`.
prompt_token_count (Optional[int]): The number of tokens in the prompt,
used to populate the "usage" field (not used when stream).
completion_token_count (Optional[int]): The number of tokens in the
response, used to populate the "usage" field (not used when stream).
token_logprobs (Optional[List[float]]): The log probabilities per token,
in token order.
top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
tokens to logprobs for the top N tokens at each token position.
tokens (Optional[List[int]]): List of tokens to return with the logprobs structure.
Returns:
dict: A dictionary containing the response, imitating OpenAI's API
dict: A dictionary containing the response, in the same format as
OpenAI's API.
"""
token_logprobs = token_logprobs if token_logprobs else []
top_logprobs = top_tokens if top_tokens else []
# Static response
response = {
@ -237,7 +261,11 @@ class APIHandler(BaseHTTPRequestHandler):
"choices": [
{
"index": 0,
"logprobs": None,
"logprobs": {
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"tokens": tokens,
},
"finish_reason": finish_reason,
}
],
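
For reference, with `logprobs=2` the per-choice `logprobs` entry takes the shape sketched below; the token ids and values are hypothetical.

```python
# Hypothetical shape of choices[0]["logprobs"] for a two-token completion
# requested with logprobs=2.
example_logprobs = {
    "token_logprobs": [-0.12, -1.87],  # log probability of each sampled token
    "top_logprobs": [                  # top-2 {token_id: logprob} per position
        {362: -0.12, 1001: -2.35},
        {77: -1.87, 264: -1.93},
    ],
    "tokens": [362, 77],               # the sampled token ids
}
```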
@ -281,8 +309,8 @@ class APIHandler(BaseHTTPRequestHandler):
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
stop_id_sequences (List[List[int]]):
A list of stop words passed to the stopping_criteria function
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
@ -290,7 +318,9 @@ class APIHandler(BaseHTTPRequestHandler):
finish_reason = "length"
stop_sequence_suffix = None
logging.debug(f"Starting completion:")
for (token, _), _ in zip(
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(
generate_step(
prompt=prompt,
model=self.model,
@ -305,6 +335,16 @@ class APIHandler(BaseHTTPRequestHandler):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
if self.logprobs > 0:
sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
token_logprobs.append(logprobs[token].item())
stop_condition = stopping_criteria(
tokens, stop_id_sequences, self.tokenizer.eos_token_id
)
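
The top-k selection above can be sketched in isolation as follows, using a toy log-probability vector and `logprobs = 2`.

```python
# Standalone sketch of the top-k extraction used above (toy values).
import mlx.core as mx

logprobs = mx.array([-3.2, -0.4, -1.1, -2.7])
k = 2

# argpartition on -logprobs moves the k most probable token ids into the
# first k positions without fully sorting the vocabulary.
top_indices = mx.argpartition(-logprobs, kth=k - 1)[:k]
top_values = logprobs[top_indices]

print(dict(zip(top_indices.tolist(), top_values.tolist())))
# e.g. {1: -0.4, 2: -1.1}; order within the top k is not guaranteed
```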
@ -322,7 +362,15 @@ class APIHandler(BaseHTTPRequestHandler):
if stop_sequence_suffix is None
else detokenizer.text[: -len(stop_sequence_suffix)]
)
response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
response = self.generate_response(
text,
finish_reason,
len(prompt),
len(tokens),
token_logprobs=token_logprobs,
top_tokens=top_tokens,
tokens=tokens,
)
response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
@ -458,7 +506,6 @@ class APIHandler(BaseHTTPRequestHandler):
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)

View File

@ -149,10 +149,11 @@ def generate_step(
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nucleus sampling, higher means model considers
more less likely words.
logit_bias (dictionary, optional): Additive logit bias.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing
one token and probability per call.
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
"""
def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
@ -160,7 +161,7 @@ def generate_step(
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
softmax_logits = mx.softmax(logits)
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
token = mx.argmax(logits, axis=-1)
@ -170,8 +171,7 @@ def generate_step(
else:
token = mx.random.categorical(logits * (1 / temp))
prob = softmax_logits[0, token]
return token, prob
return token, logprobs
if repetition_penalty and (
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
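
The `logits - mx.logsumexp(logits)` expression above is a numerically stable log-softmax. A quick sanity check, assuming a single `[1, vocab]` logits row as in `sample()`:

```python
# Verify that subtracting logsumexp matches log(softmax(.)) for a single row.
import mlx.core as mx

logits = mx.array([[2.0, 0.5, -1.0]])

logprobs = logits - mx.logsumexp(logits)        # stable log-softmax
reference = mx.log(mx.softmax(logits, axis=-1))

print(mx.allclose(logprobs, reference).item())  # True
print(mx.exp(logprobs).sum().item())            # 1.0, a proper distribution
```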
@ -202,24 +202,24 @@ def generate_step(
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty
)
y, prob = sample(logits)
y, logprobs = sample(logits)
repetition_context.append(y.item())
else:
y, prob = sample(logits)
y, logprobs = sample(logits)
if repetition_context_size:
if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
return y, prob
return y, logprobs.squeeze(0)
y, p = _step(y)
y, logprobs = _step(y)
mx.async_eval(y)
while True:
next_y, next_p = _step(y)
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
yield y.item(), p
y, p = next_y, next_p
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
def stream_generate(
@ -249,7 +249,7 @@ def stream_generate(
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for (token, prob), n in zip(
for (token, _), n in zip(
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
@ -301,7 +301,7 @@ def generate(
tic = time.perf_counter()
detokenizer.reset()
for (token, prob), n in zip(
for (token, logprobs), n in zip(
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
@ -316,7 +316,7 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, prob.item())
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
else:
print(detokenizer.last_segment, end="", flush=True)
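
The formatter called above receives each detokenized segment together with the probability of the token that produced it, so a minimal (hypothetical) colorizing callback could look like the sketch below. It would be supplied via the `formatter` argument of `generate`.

```python
# Hypothetical formatter matching the (segment, probability) call above:
# dim tokens the model was unsure about, print confident ones normally.
def print_with_confidence(segment: str, prob: float) -> None:
    if prob < 0.5:
        print(f"\033[2m{segment}\033[0m", end="", flush=True)  # ANSI dim
    else:
        print(segment, end="", flush=True)
```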