Logprobs info to completion API (#806)

* Initial implementation

* Fix handling of return_step_logits in return

* Fixed OpenAI parameter expectations and logprob structure and datatypes

* pre-commit black formatting

* Remove unused parameter

* fix log probs

* fix colorize

* nits in server

* nits in server

* Fix top_logprobs structure (a dict) and include tokens in logprobs response

* nits

* fix types

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2024-06-23 13:35:13 -04:00 (committed by GitHub)
Parent: a7598e9456
Commit: 1d701a1831
3 changed files with 94 additions and 43 deletions


@@ -17,7 +17,7 @@ mlx_lm.server --model <path_to_model_or_hf_repo>

 For example:

 ```shell
-mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
 ```

 This will start a text generation server on port `8080` of the `localhost`
@@ -74,3 +74,7 @@ curl localhost:8080/v1/chat/completions \
 - `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
   values. Defaults to `None`.
+- `logprobs`: (Optional) An integer specifying the number of top tokens and
+  corresponding log probabilities to return for each output in the generated
+  sequence. If set, this can be any value between 1 and 10, inclusive.
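
To illustrate the new option, here is a minimal client-side sketch (not part of this diff) of requesting per-token log probabilities. It assumes the server from the example above is listening on port 8080, that the text-completion endpoint is served at `/v1/completions`, and that the `requests` package is installed:

```python
# Hypothetical client sketch: ask a local mlx_lm.server instance for the
# top-3 log probabilities per generated token.
import requests

resp = requests.post(
    "http://localhost:8080/v1/completions",
    json={
        "prompt": "Why is the sky blue?",
        "max_tokens": 16,
        "temperature": 0.0,
        "logprobs": 3,  # must be between 1 and 10, per the docs above
    },
)
choice = resp.json()["choices"][0]
print(choice["logprobs"]["token_logprobs"])   # one logprob per generated token
print(choice["logprobs"]["top_logprobs"][0])  # {token_id: logprob} for position 0
print(choice["logprobs"]["tokens"])           # the generated token IDs
```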


@@ -6,11 +6,13 @@ import logging
 import time
 import uuid
 import warnings
+from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
+from transformers import PreTrainedTokenizer

 from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
@@ -27,18 +29,22 @@ def stopping_criteria(
     eos_token_id: Union[int, None],
 ) -> StopCondition:
     """
-    Determines whether the token generation should stop based on predefined conditions.
+    Determines whether the token generation should stop based on predefined
+    conditions.

     Args:
         tokens (List[int]): The current sequence of generated tokens.
-        stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
-            If the end of the `tokens` list matches any of these sequences, the generation should stop.
-        eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
-            the generation should stop.
+        stop_id_sequences (List[List[[int]]): A list of integer lists, each
+            representing a sequence of token IDs. If the end of the `tokens`
+            list matches any of these sequences, the generation should stop.
+        eos_token_id (Union[int, None]): The token ID that represents the
+            end-of-sequence. If the last token in `tokens` matches this, the
+            generation should stop.

     Returns:
-        StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
-            and how many tokens should be trimmed from the end if it has (`trim_length`).
+        StopCondition: A named tuple indicating whether the stop condition has
+            been met (`stop_met`) and how many tokens should be trimmed from
+            the end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
         return StopCondition(stop_met=True, trim_length=1)
@@ -53,7 +59,10 @@ def stopping_criteria(

 def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
-        "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
+        "system_prompt": (
+            "A chat between a curious user and an artificial intelligence "
+            "assistant. The assistant follows the given rules no matter what."
+        ),
         "system": "ASSISTANT's RULE: ",
         "user": "USER: ",
         "assistant": "ASSISTANT: ",
@@ -136,7 +145,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
         self.logit_bias = self.body.get("logit_bias", None)
+        self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()

         # Get stop id sequences, if provided
@@ -184,6 +193,11 @@ class APIHandler(BaseHTTPRequestHandler):
         ):
             raise ValueError("repetition_penalty must be a non-negative float")

+        if self.logprobs != -1 and not (0 < self.logprobs <= 10):
+            raise ValueError(
+                f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
+            )
+
         if (
             not isinstance(self.repetition_context_size, int)
             or self.repetition_context_size < 0
@@ -208,24 +222,34 @@
         finish_reason: Union[Literal["length", "stop"], None],
         prompt_token_count: Optional[int] = None,
         completion_token_count: Optional[int] = None,
+        token_logprobs: Optional[List[float]] = None,
+        top_tokens: Optional[List[Dict[int, float]]] = None,
+        tokens: Optional[List[int]] = None,
     ) -> dict:
         """
-        Generate a single response packet based on response type (stream or not), completion type and parameters.
+        Generate a single response packet based on response type (stream or
+        not), completion type and parameters.

         Args:
            text (str): Text generated by model
-            finish_reason (Union[Literal["length", "stop"], None]):
-                The reason the response is being sent: "length", "stop" or None
-            prompt_token_count (Optional[int]):
-                The amount of tokens in the prompt,
-                used to populate the "usage" field (not used when stream)
-            completion_token_count (Optional[int]):
-                The amount of tokens in the response,
-                used to populate the "usage" field (not used when stream)
+            finish_reason (Union[Literal["length", "stop"], None]): The reason the
+                response is being sent: "length", "stop" or `None`.
+            prompt_token_count (Optional[int]): The number of tokens in the prompt,
+                used to populate the "usage" field (not used when stream).
+            completion_token_count (Optional[int]): The number of tokens in the
+                response, used to populate the "usage" field (not used when stream).
+            token_logprobs (Optional[List[float]]): The log probabilities per token,
+                in token order.
+            top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
+                tokens to logprobs for the top N tokens at each token position.
+            tokens (Optional[List[int]]): List of tokens to return with logprobs structure

         Returns:
-            dict: A dictionary containing the response, imitating OpenAI's API
+            dict: A dictionary containing the response, in the same format as
+            OpenAI's API.
         """
+        token_logprobs = token_logprobs if token_logprobs else []
+        top_logprobs = top_tokens if top_tokens else []

         # Static response
         response = {
@@ -237,7 +261,11 @@
             "choices": [
                 {
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": {
+                        "token_logprobs": token_logprobs,
+                        "top_logprobs": top_logprobs,
+                        "tokens": tokens,
+                    },
                     "finish_reason": finish_reason,
                 }
             ],
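
For reference, a sketch of the shape this gives a non-streaming choice when a client sets `logprobs` (the token IDs and values below are made up, not taken from the diff):

```python
# Illustrative choices[0] payload for a request with "logprobs": 2.
choice = {
    "index": 0,
    "logprobs": {
        "token_logprobs": [-0.11, -1.32],  # logprob of each sampled token, in order
        "top_logprobs": [                  # top-N {token_id: logprob} per position
            {450: -0.11, 278: -2.57},
            {3186: -1.32, 1234: -1.90},
        ],
        "tokens": [450, 3186],             # the sampled token IDs
    },
    "finish_reason": "length",
}
```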
@@ -281,8 +309,8 @@
         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
-            stop_id_sequences (List[List[int]]):
-                A list of stop words passed to the stopping_criteria function
+            stop_id_sequences (List[List[int]]): A list of stop words passed
+                to the stopping_criteria function
         """
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()
@@ -290,7 +318,9 @@
         finish_reason = "length"
         stop_sequence_suffix = None
         logging.debug(f"Starting completion:")
-        for (token, _), _ in zip(
+        token_logprobs = []
+        top_tokens = []
+        for (token, logprobs), _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -305,6 +335,16 @@
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
+
+            if self.logprobs > 0:
+                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                top_indices = sorted_indices[: self.logprobs]
+                top_logprobs = logprobs[top_indices]
+                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
+                top_tokens.append(dict(top_token_info))
+
+            token_logprobs.append(logprobs[token].item())
+
             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
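
The top-N selection above uses `mx.argpartition` on the negated log probabilities, which moves the k largest entries to the front without a full sort. A standalone sketch of the same idea over a toy vocabulary (made-up values):

```python
import mlx.core as mx

# Toy "vocabulary" of 6 tokens with made-up log probabilities.
logprobs = mx.array([-5.2, -0.3, -4.1, -1.9, -6.0, -2.4])
k = 3

# argpartition on the negated values places the k largest logprobs in the
# first k positions (in no particular order), cheaper than a full sort.
top_indices = mx.argpartition(-logprobs, kth=k - 1)[:k]
top_values = logprobs[top_indices]

# The same {token_id: logprob} mapping the server accumulates per position.
top_tokens = dict(zip(top_indices.tolist(), top_values.tolist()))
print(top_tokens)  # e.g. {1: -0.3, 3: -1.9, 5: -2.4}, in some order
```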
@@ -322,7 +362,15 @@
             if stop_sequence_suffix is None
             else detokenizer.text[: -len(stop_sequence_suffix)]
         )
-        response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
+        response = self.generate_response(
+            text,
+            finish_reason,
+            len(prompt),
+            len(tokens),
+            token_logprobs=token_logprobs,
+            top_tokens=top_tokens,
+            tokens=tokens,
+        )

         response_json = json.dumps(response).encode()
         indent = "\t"  # Backslashes can't be inside of f-strings
@@ -458,7 +506,6 @@
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
         prompt = self.tokenizer.encode(prompt_text)
-
         return mx.array(prompt)


@@ -149,10 +149,11 @@ def generate_step(
           consider for repetition penalty. Default: ``20``.
         top_p (float, optional): Nulceus sampling, higher means model considers
           more less likely words.
+        logit_bias (dictionary, optional): Additive logit bias.

     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing
-        one token and probability per call.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+        one token and a vector of log probabilities.
     """

     def sample(logits: mx.array) -> Tuple[mx.array, float]:
@@ -160,7 +161,7 @@ def generate_step(
             indices = mx.array(list(logit_bias.keys()))
             values = mx.array(list(logit_bias.values()))
             logits[:, indices] += values
-        softmax_logits = mx.softmax(logits)
+        logprobs = logits - mx.logsumexp(logits)

         if temp == 0:
             token = mx.argmax(logits, axis=-1)
@@ -170,8 +171,7 @@
         else:
             token = mx.random.categorical(logits * (1 / temp))

-        prob = softmax_logits[0, token]
-        return token, prob
+        return token, logprobs

     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
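
The `sample` helper now stays in log space: `logits - mx.logsumexp(logits)` is the log-softmax, so exponentiating it recovers the old `mx.softmax(logits)` values. A small sketch of that identity with toy logits:

```python
import mlx.core as mx

# Toy logits for a 4-token vocabulary.
logits = mx.array([2.0, 1.0, 0.5, -1.0])

# Log-softmax via the logsumexp identity: log p_i = x_i - log(sum_j exp(x_j)).
logprobs = logits - mx.logsumexp(logits)

# Exponentiating recovers the softmax probabilities, which sum to 1.
probs = mx.exp(logprobs)
print(probs)                 # same values as mx.softmax(logits)
print(mx.sum(probs).item())  # ~1.0
```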
@@ -202,24 +202,24 @@
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
             repetition_context.append(y.item())
         else:
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)

         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return y, prob
+        return y, logprobs.squeeze(0)

-    y, p = _step(y)
+    y, logprobs = _step(y)

     mx.async_eval(y)

     while True:
-        next_y, next_p = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), p
-        y, p = next_y, next_p
+        yield y.item(), logprobs
+        y, logprobs = next_y, next_logprobs


 def stream_generate(
@@ -249,7 +249,7 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
     detokenizer.reset()

-    for (token, prob), n in zip(
+    for (token, _), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -301,7 +301,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

-    for (token, prob), n in zip(
+    for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -316,7 +316,7 @@
             if formatter:
                 # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
-                formatter(detokenizer.last_segment, prob.item())
+                formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
                 print(detokenizer.last_segment, end="", flush=True)
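
Since `generate_step` now yields log probabilities, the colorizing formatter converts the sampled token's entry back to a probability with `mx.exp`, e.g. (made-up value):

```python
import mlx.core as mx

# Suppose the sampled token's log probability is -0.105 (made-up value).
token_logprob = mx.array(-0.105)

# The formatter expects a probability in [0, 1], so exponentiate it back.
prob = mx.exp(token_logprob).item()
print(round(prob, 3))  # ~0.9
```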