Mirror of https://github.com/ml-explore/mlx-examples.git
Logprobs info to completion API (#806)
* Initial implementation
* Fix handling of return_step_logits in return
* Fixed OpenAI parameter expectations and logprob structure and datatypes
* pre-commit black formatting
* Remove unused parameter
* fix log probs
* fix colorize
* nits in server
* nits in server
* Fix top_logprobs structure (a dict) and include tokens in logprobs response
* nits
* fix types

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent a7598e9456
commit 1d701a1831
@@ -17,7 +17,7 @@ mlx_lm.server --model <path_to_model_or_hf_repo>
 For example:

 ```shell
-mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
 ```

 This will start a text generation server on port `8080` of the `localhost`
@@ -74,3 +74,7 @@ curl localhost:8080/v1/chat/completions \

 - `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
   values. Defaults to `None`.
+
+- `logprobs`: (Optional) An integer specifying the number of top tokens and
+  corresponding log probabilities to return for each output in the generated
+  sequence. If set, this can be any value between 1 and 10, inclusive.
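For illustration, a minimal sketch (not part of the diff) of a client request that exercises the new `logprobs` option against a locally running server. The endpoint and the `logprobs` parameter come from the README above, and the response fields follow the structure added to `generate_response` further down; the prompt, `max_tokens`, and all values are assumptions.

```python
# Hedged sketch: query a locally running `mlx_lm.server` (see README above)
# and read back per-token log probabilities. Prompt and values are examples.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 8,
        "logprobs": 3,  # ask for the top-3 alternatives per generated token (1-10)
    },
)
logprobs = resp.json()["choices"][0]["logprobs"]
# token_logprobs: log probability of each sampled token, in order
# top_logprobs:   one {token_id: logprob} dict per position
# tokens:         the sampled token IDs
for tok, lp in zip(logprobs["tokens"], logprobs["token_logprobs"]):
    print(tok, lp)
```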
@@ -6,11 +6,13 @@ import logging
 import time
 import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union

 import mlx.core as mx
+import mlx.nn as nn
 from transformers import PreTrainedTokenizer

+from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
@@ -27,18 +29,22 @@ def stopping_criteria(
     eos_token_id: Union[int, None],
 ) -> StopCondition:
     """
-    Determines whether the token generation should stop based on predefined conditions.
+    Determines whether the token generation should stop based on predefined
+    conditions.

     Args:
         tokens (List[int]): The current sequence of generated tokens.
-        stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
-            If the end of the `tokens` list matches any of these sequences, the generation should stop.
-        eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
-            the generation should stop.
+        stop_id_sequences (List[List[[int]]): A list of integer lists, each
+            representing a sequence of token IDs. If the end of the `tokens`
+            list matches any of these sequences, the generation should stop.
+        eos_token_id (Union[int, None]): The token ID that represents the
+            end-of-sequence. If the last token in `tokens` matches this, the
+            generation should stop.

     Returns:
-        StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
-            and how many tokens should be trimmed from the end if it has (`trim_length`).
+        StopCondition: A named tuple indicating whether the stop condition has
+            been met (`stop_met`) and how many tokens should be trimmed from the
+            end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
         return StopCondition(stop_met=True, trim_length=1)
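A small, hypothetical illustration of the contract the rewrapped docstring describes (a standalone re-implementation of the documented behavior, not the code from the diff):

```python
# Hypothetical example of the stop-condition semantics documented above.
tokens = [12, 99, 7, 42, 13]          # tokens generated so far
stop_id_sequences = [[42, 13], [88]]  # stop once the output ends with either

for stop_ids in stop_id_sequences:
    if tokens[-len(stop_ids):] == stop_ids:
        # The tail [42, 13] matches, so generation stops and the two matched
        # tokens are trimmed: StopCondition(stop_met=True, trim_length=2).
        print("stop_met=True, trim_length =", len(stop_ids))
```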
@@ -53,7 +59,10 @@ def stopping_criteria(

 def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
-        "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
+        "system_prompt": (
+            "A chat between a curious user and an artificial intelligence "
+            "assistant. The assistant follows the given rules no matter what."
+        ),
         "system": "ASSISTANT's RULE: ",
         "user": "USER: ",
         "assistant": "ASSISTANT: ",
@@ -136,7 +145,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
         self.logit_bias = self.body.get("logit_bias", None)
-
+        self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()

         # Get stop id sequences, if provided
@@ -184,6 +193,11 @@ class APIHandler(BaseHTTPRequestHandler):
         ):
             raise ValueError("repetition_penalty must be a non-negative float")

+        if self.logprobs != -1 and not (0 < self.logprobs <= 10):
+            raise ValueError(
+                f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
+            )
+
         if (
             not isinstance(self.repetition_context_size, int)
             or self.repetition_context_size < 0
@@ -208,24 +222,34 @@ class APIHandler(BaseHTTPRequestHandler):
         finish_reason: Union[Literal["length", "stop"], None],
         prompt_token_count: Optional[int] = None,
         completion_token_count: Optional[int] = None,
+        token_logprobs: Optional[List[float]] = None,
+        top_tokens: Optional[List[Dict[int, float]]] = None,
+        tokens: Optional[List[int]] = None,
     ) -> dict:
         """
-        Generate a single response packet based on response type (stream or not), completion type and parameters.
+        Generate a single response packet based on response type (stream or
+        not), completion type and parameters.

         Args:
             text (str): Text generated by model
-            finish_reason (Union[Literal["length", "stop"], None]):
-                The reason the response is being sent: "length", "stop" or None
-            prompt_token_count (Optional[int]):
-                The amount of tokens in the prompt,
-                used to populate the "usage" field (not used when stream)
-            completion_token_count (Optional[int]):
-                The amount of tokens in the response,
-                used to populate the "usage" field (not used when stream)
+            finish_reason (Union[Literal["length", "stop"], None]): The reason the
+                response is being sent: "length", "stop" or `None`.
+            prompt_token_count (Optional[int]): The number of tokens in the prompt,
+                used to populate the "usage" field (not used when stream).
+            completion_token_count (Optional[int]): The number of tokens in the
+                response, used to populate the "usage" field (not used when stream).
+            token_logprobs (Optional[List[float]]): The log probabilities per token,
+                in token order.
+            top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
+                tokens to logprobs for the top N tokens at each token position.
+            tokens (Optional[List[int]]): List of tokens to return with logprobs structure

         Returns:
-            dict: A dictionary containing the response, imitating OpenAI's API
+            dict: A dictionary containing the response, in the same format as
+                OpenAI's API.
         """
+        token_logprobs = token_logprobs if token_logprobs else []
+        top_logprobs = top_tokens if top_tokens else []

         # Static response
         response = {
@@ -237,7 +261,11 @@ class APIHandler(BaseHTTPRequestHandler):
             "choices": [
                 {
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": {
+                        "token_logprobs": token_logprobs,
+                        "top_logprobs": top_logprobs,
+                        "tokens": tokens,
+                    },
                     "finish_reason": finish_reason,
                 }
             ],
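For reference, a hedged sketch of what a populated `logprobs` entry might look like for a two-token completion requested with `logprobs=2`; the token IDs and values are invented:

```python
# Illustrative values only; real token IDs and log probabilities depend on
# the tokenizer and model being served.
example_logprobs = {
    "token_logprobs": [-0.31, -1.02],   # log prob of each sampled token
    "top_logprobs": [                   # top-2 alternatives per position
        {5092: -0.31, 310: -1.87},
        {278: -1.02, 13: -1.35},
    ],
    "tokens": [5092, 278],              # the sampled token IDs
}
```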
@@ -281,8 +309,8 @@ class APIHandler(BaseHTTPRequestHandler):

         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
-            stop_id_sequences (List[List[int]]):
-                A list of stop words passed to the stopping_criteria function
+            stop_id_sequences (List[List[int]]): A list of stop words passed
+                to the stopping_criteria function
         """
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()
@@ -290,7 +318,9 @@ class APIHandler(BaseHTTPRequestHandler):
         finish_reason = "length"
         stop_sequence_suffix = None
         logging.debug(f"Starting completion:")
-        for (token, _), _ in zip(
+        token_logprobs = []
+        top_tokens = []
+        for (token, logprobs), _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -305,6 +335,16 @@ class APIHandler(BaseHTTPRequestHandler):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
+
+            if self.logprobs > 0:
+                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                top_indices = sorted_indices[: self.logprobs]
+                top_logprobs = logprobs[top_indices]
+                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
+                top_tokens.append(dict(top_token_info))
+
+            token_logprobs.append(logprobs[token].item())
+
             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
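The block above uses `mx.argpartition` to pick the `self.logprobs` largest log probabilities without sorting the whole vocabulary. A self-contained sketch of the same idea on a toy vocabulary (values invented):

```python
import mlx.core as mx

# Toy logits over a five-token vocabulary, purely for illustration.
logits = mx.array([2.0, 0.5, -1.0, 3.0, 1.5])
logprobs = logits - mx.logsumexp(logits)  # log-softmax, as in sample() below
k = 2

# Negating the values makes the k smallest of -logprobs the k largest
# logprobs; argpartition avoids a full sort of the vocabulary.
top_indices = mx.argpartition(-logprobs, kth=k - 1)[:k]
top_values = logprobs[top_indices]
print(dict(zip(top_indices.tolist(), top_values.tolist())))
# e.g. {3: -0.53, 0: -1.53} (order within the top-k is not guaranteed)
```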
@@ -322,7 +362,15 @@ class APIHandler(BaseHTTPRequestHandler):
             if stop_sequence_suffix is None
             else detokenizer.text[: -len(stop_sequence_suffix)]
         )
-        response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
+        response = self.generate_response(
+            text,
+            finish_reason,
+            len(prompt),
+            len(tokens),
+            token_logprobs=token_logprobs,
+            top_tokens=top_tokens,
+            tokens=tokens,
+        )

         response_json = json.dumps(response).encode()
         indent = "\t"  # Backslashes can't be inside of f-strings
@@ -458,7 +506,6 @@ class APIHandler(BaseHTTPRequestHandler):

         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-
         prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)

@@ -149,10 +149,11 @@ def generate_step(
           consider for repetition penalty. Default: ``20``.
        top_p (float, optional): Nulceus sampling, higher means model considers
           more less likely words.
+       logit_bias (dictionary, optional): Additive logit bias.

     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing
-          one token and probability per call.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+          one token and a vector of log probabilities.
     """

     def sample(logits: mx.array) -> Tuple[mx.array, float]:
@@ -160,7 +161,7 @@ def generate_step(
             indices = mx.array(list(logit_bias.keys()))
             values = mx.array(list(logit_bias.values()))
             logits[:, indices] += values
-        softmax_logits = mx.softmax(logits)
+        logprobs = logits - mx.logsumexp(logits)

         if temp == 0:
             token = mx.argmax(logits, axis=-1)
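The replacement line computes a log-softmax directly from the logits. A quick self-contained check of the identity it relies on (a sketch, not part of the diff):

```python
import mlx.core as mx

logits = mx.array([[2.0, 0.5, -1.0, 3.0]])  # batch of one, toy vocabulary

# As written in sample(): subtract the log partition function.
logprobs = logits - mx.logsumexp(logits)

# The same log probabilities computed via the softmax.
reference = mx.log(mx.softmax(logits, axis=-1))

print(mx.allclose(logprobs, reference).item())  # True
```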
@@ -170,8 +171,7 @@ def generate_step(
         else:
             token = mx.random.categorical(logits * (1 / temp))

-        prob = softmax_logits[0, token]
-        return token, prob
+        return token, logprobs

     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
@@ -202,24 +202,24 @@ def generate_step(
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
             repetition_context.append(y.item())
         else:
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)

         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return y, prob
+        return y, logprobs.squeeze(0)

-    y, p = _step(y)
+    y, logprobs = _step(y)

     mx.async_eval(y)
     while True:
-        next_y, next_p = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), p
-        y, p = next_y, next_p
+        yield y.item(), logprobs
+        y, logprobs = next_y, next_logprobs


 def stream_generate(
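With the renamed return values, each step of `generate_step` now yields the sampled token together with a full log-probability vector. A hedged usage sketch; the model name is taken from the README example above and the prompt is invented:

```python
import mlx.core as mx
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("The capital of France is"))

# Take five steps; logprobs is a vector over the whole vocabulary, so index
# it with the sampled token to recover that token's log probability.
for (token, logprobs), _ in zip(generate_step(prompt, model), range(5)):
    print(token, logprobs[token].item())
```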
@@ -249,7 +249,7 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer

     detokenizer.reset()
-    for (token, prob), n in zip(
+    for (token, _), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -301,7 +301,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

-    for (token, prob), n in zip(
+    for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -316,7 +316,7 @@ def generate(
     if formatter:
         # We have to finalize so that the prob corresponds to the last segment
         detokenizer.finalize()
-        formatter(detokenizer.last_segment, prob.item())
+        formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
     else:
         print(detokenizer.last_segment, end="", flush=True)