Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00.
Logprobs info to completion API (#806)
* Initial implementation
* Fix handling of return_step_logits in return
* Fixed OpenAI parameter expectations and logprob structure and datatypes
* pre-commit black formatting
* Remove unused parameter
* fix log probs
* fix colorize
* nits in server
* nits in server
* Fix top_logprobs structure (a dict) and include tokens in logprobs response
* nits
* fix types

Co-authored-by: Awni Hannun <awni@apple.com>
Parent: a7598e9456
Commit: 1d701a1831
@@ -17,7 +17,7 @@ mlx_lm.server --model <path_to_model_or_hf_repo>
 For example:
 
 ```shell
-mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
 ```
 
 This will start a text generation server on port `8080` of the `localhost`
@@ -74,3 +74,7 @@ curl localhost:8080/v1/chat/completions \
 
 - `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
   values. Defaults to `None`.
+
+- `logprobs`: (Optional) An integer specifying the number of top tokens and
+  corresponding log probabilities to return for each output in the generated
+  sequence. If set, this can be any value between 1 and 10, inclusive.
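As a quick orientation for the new parameter, here is a minimal client sketch that asks the server for per-token log probabilities. It is only a sketch: it assumes the server was started as in the example above, that the `requests` package is installed, and that the prompt and values are purely illustrative.

```python
# Minimal sketch: request the top-3 log probabilities per generated token
# from a locally running mlx_lm.server (localhost:8080, as started above).
import requests  # assumed installed; any HTTP client works

payload = {
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "max_tokens": 8,
    "logprobs": 3,  # any integer from 1 to 10, per the description above
}

resp = requests.post(
    "http://localhost:8080/v1/chat/completions", json=payload
).json()

logprobs = resp["choices"][0]["logprobs"]
print(logprobs["tokens"])          # generated token IDs
print(logprobs["token_logprobs"])  # log probability of each generated token
print(logprobs["top_logprobs"])    # per position: token ID -> log probability
```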
@@ -6,11 +6,13 @@ import logging
 import time
 import uuid
 import warnings
+from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
+from transformers import PreTrainedTokenizer
 
 from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
@@ -27,18 +29,22 @@ def stopping_criteria(
     eos_token_id: Union[int, None],
 ) -> StopCondition:
     """
-    Determines whether the token generation should stop based on predefined conditions.
+    Determines whether the token generation should stop based on predefined
+    conditions.
 
     Args:
         tokens (List[int]): The current sequence of generated tokens.
-        stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
-            If the end of the `tokens` list matches any of these sequences, the generation should stop.
-        eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
-            the generation should stop.
+        stop_id_sequences (List[List[[int]]): A list of integer lists, each
+            representing a sequence of token IDs. If the end of the `tokens`
+            list matches any of these sequences, the generation should stop.
+        eos_token_id (Union[int, None]): The token ID that represents the
+            end-of-sequence. If the last token in `tokens` matches this, the
+            generation should stop.
 
     Returns:
-        StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
-            and how many tokens should be trimmed from the end if it has (`trim_length`).
+        StopCondition: A named tuple indicating whether the stop condition has
+            been met (`stop_met`) and how many tokens should be trimmed from the
+            end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
         return StopCondition(stop_met=True, trim_length=1)
@@ -53,7 +59,10 @@ def stopping_criteria(
 
 def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
-        "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
+        "system_prompt": (
+            "A chat between a curious user and an artificial intelligence "
+            "assistant. The assistant follows the given rules no matter what."
+        ),
         "system": "ASSISTANT's RULE: ",
         "user": "USER: ",
         "assistant": "ASSISTANT: ",
@@ -136,7 +145,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
         self.logit_bias = self.body.get("logit_bias", None)
+        self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
@@ -184,6 +193,11 @@ class APIHandler(BaseHTTPRequestHandler):
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
 
+        if self.logprobs != -1 and not (0 < self.logprobs <= 10):
+            raise ValueError(
+                f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
+            )
+
         if (
             not isinstance(self.repetition_context_size, int)
             or self.repetition_context_size < 0
@@ -208,24 +222,34 @@ class APIHandler(BaseHTTPRequestHandler):
         finish_reason: Union[Literal["length", "stop"], None],
         prompt_token_count: Optional[int] = None,
         completion_token_count: Optional[int] = None,
+        token_logprobs: Optional[List[float]] = None,
+        top_tokens: Optional[List[Dict[int, float]]] = None,
+        tokens: Optional[List[int]] = None,
     ) -> dict:
         """
-        Generate a single response packet based on response type (stream or not), completion type and parameters.
+        Generate a single response packet based on response type (stream or
+        not), completion type and parameters.
 
         Args:
             text (str): Text generated by model
-            finish_reason (Union[Literal["length", "stop"], None]):
-                The reason the response is being sent: "length", "stop" or None
-            prompt_token_count (Optional[int]):
-                The amount of tokens in the prompt,
-                used to populate the "usage" field (not used when stream)
-            completion_token_count (Optional[int]):
-                The amount of tokens in the response,
-                used to populate the "usage" field (not used when stream)
+            finish_reason (Union[Literal["length", "stop"], None]): The reason the
+                response is being sent: "length", "stop" or `None`.
+            prompt_token_count (Optional[int]): The number of tokens in the prompt,
+                used to populate the "usage" field (not used when stream).
+            completion_token_count (Optional[int]): The number of tokens in the
+                response, used to populate the "usage" field (not used when stream).
+            token_logprobs (Optional[List[float]]): The log probabilities per token,
+                in token order.
+            top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
+                tokens to logprobs for the top N tokens at each token position.
+            tokens (Optional[List[int]]): List of tokens to return with logprobs structure
 
         Returns:
-            dict: A dictionary containing the response, imitating OpenAI's API
+            dict: A dictionary containing the response, in the same format as
+                OpenAI's API.
         """
+        token_logprobs = token_logprobs if token_logprobs else []
+        top_logprobs = top_tokens if top_tokens else []
 
         # Static response
         response = {
@@ -237,7 +261,11 @@ class APIHandler(BaseHTTPRequestHandler):
             "choices": [
                 {
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": {
+                        "token_logprobs": token_logprobs,
+                        "top_logprobs": top_logprobs,
+                        "tokens": tokens,
+                    },
                     "finish_reason": finish_reason,
                 }
             ],
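For orientation, after this change `choices[0]["logprobs"]` in a response is a dictionary rather than `None`. An illustrative sketch of its shape, assuming a request with `logprobs` set to 2 (token IDs and values are made up):

```python
# Illustrative shape of the new logprobs field; all numbers are made up.
logprobs_field = {
    "tokens": [450, 6575],             # generated token IDs, in order
    "token_logprobs": [-0.12, -1.38],  # log probability of each generated token
    "top_logprobs": [                  # one dict per position: token ID -> logprob
        {450: -0.12, 901: -2.43},
        {6575: -1.38, 287: -1.95},
    ],
}
```

Note that once the response is serialized to JSON, the integer keys in `top_logprobs` arrive as strings on the client side.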
@@ -281,8 +309,8 @@ class APIHandler(BaseHTTPRequestHandler):
 
         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
-            stop_id_sequences (List[List[int]]):
-                A list of stop words passed to the stopping_criteria function
+            stop_id_sequences (List[List[int]]): A list of stop words passed
+                to the stopping_criteria function
         """
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()
@@ -290,7 +318,9 @@ class APIHandler(BaseHTTPRequestHandler):
         finish_reason = "length"
         stop_sequence_suffix = None
         logging.debug(f"Starting completion:")
-        for (token, _), _ in zip(
+        token_logprobs = []
+        top_tokens = []
+        for (token, logprobs), _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -305,6 +335,16 @@ class APIHandler(BaseHTTPRequestHandler):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
+
+            if self.logprobs > 0:
+                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                top_indices = sorted_indices[: self.logprobs]
+                top_logprobs = logprobs[top_indices]
+                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
+                top_tokens.append(dict(top_token_info))
+
+            token_logprobs.append(logprobs[token].item())
+
             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
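A side note on the top-k selection above: `mx.argpartition` only guarantees that the first `kth + 1` positions hold the smallest values of its input (here the negated log probabilities), so the selected top tokens are not necessarily sorted by probability. A standalone sketch of the same idea, with made-up values:

```python
# Standalone sketch of the top-k logprob selection used above; values are made up.
import mlx.core as mx

logprobs = mx.array([-3.2, -0.1, -5.0, -1.7, -2.4])
k = 2

# Partition on the negated logprobs so the k largest logprobs land first;
# note the resulting indices are not sorted by probability.
top_indices = mx.argpartition(-logprobs, kth=k - 1)[:k]
top = dict(zip(top_indices.tolist(), logprobs[top_indices].tolist()))
print(top)  # e.g. {1: -0.1, 3: -1.7} (order may vary)
```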
@@ -322,7 +362,15 @@ class APIHandler(BaseHTTPRequestHandler):
             if stop_sequence_suffix is None
             else detokenizer.text[: -len(stop_sequence_suffix)]
         )
-        response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
+        response = self.generate_response(
+            text,
+            finish_reason,
+            len(prompt),
+            len(tokens),
+            token_logprobs=token_logprobs,
+            top_tokens=top_tokens,
+            tokens=tokens,
+        )
 
         response_json = json.dumps(response).encode()
         indent = "\t"  # Backslashes can't be inside of f-strings
@@ -458,7 +506,6 @@ class APIHandler(BaseHTTPRequestHandler):
 
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-
         prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
@@ -149,10 +149,11 @@ def generate_step(
            consider for repetition penalty. Default: ``20``.
         top_p (float, optional): Nulceus sampling, higher means model considers
            more less likely words.
+        logit_bias (dictionary, optional): Additive logit bias.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing
-        one token and probability per call.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+        one token and a vector of log probabilities.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
@@ -160,7 +161,7 @@ def generate_step(
             indices = mx.array(list(logit_bias.keys()))
             values = mx.array(list(logit_bias.values()))
             logits[:, indices] += values
-        softmax_logits = mx.softmax(logits)
+        logprobs = logits - mx.logsumexp(logits)
 
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
@@ -170,8 +171,7 @@ def generate_step(
         else:
             token = mx.random.categorical(logits * (1 / temp))
 
-        prob = softmax_logits[0, token]
-        return token, prob
+        return token, logprobs
 
     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
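The switch from `mx.softmax` to `logits - mx.logsumexp(logits)` is the standard log-softmax identity, log softmax(x) = x - logsumexp(x), which keeps the whole vector of log probabilities around instead of a single sampled probability. A quick sanity-check sketch with illustrative logits:

```python
# Quick check that logits - logsumexp(logits) matches log(softmax(logits)).
import mlx.core as mx

logits = mx.array([[2.0, 0.5, -1.0]])  # illustrative batch of one
logprobs = logits - mx.logsumexp(logits)
reference = mx.log(mx.softmax(logits))
print(mx.allclose(logprobs, reference).item())  # True
```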
@@ -202,24 +202,24 @@ def generate_step(
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
             repetition_context.append(y.item())
         else:
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
 
         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return y, prob
+        return y, logprobs.squeeze(0)
 
-    y, p = _step(y)
+    y, logprobs = _step(y)
 
     mx.async_eval(y)
     while True:
-        next_y, next_p = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), p
-        y, p = next_y, next_p
+        yield y.item(), logprobs
+        y, logprobs = next_y, next_logprobs
 
 
 def stream_generate(
@@ -249,7 +249,7 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
 
     detokenizer.reset()
-    for (token, prob), n in zip(
+    for (token, _), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -301,7 +301,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
 
-    for (token, prob), n in zip(
+    for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
    ):
@@ -316,7 +316,7 @@ def generate(
         if formatter:
             # We have to finalize so that the prob corresponds to the last segment
             detokenizer.finalize()
-            formatter(detokenizer.last_segment, prob.item())
+            formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
         else:
             print(detokenizer.last_segment, end="", flush=True)
 