From 1d701a1831a08407c670a43b77d28e843d033551 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 23 Jun 2024 13:35:13 -0400
Subject: [PATCH] Logprobs info to completion API (#806)

* Initial implementation

* Fix handling of return_step_logits in return

* Fixed OpenAI parameter expectations and logprob structure and datatypes

* pre-commit black formatting

* Remove unused parameter

* fix log probs

* fix colorize

* nits in server

* nits in server

* Fix top_logprobs structure (a dict) and include tokens in logprobs response

* nits

* fix types

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/SERVER.md |  8 +++-
 llms/mlx_lm/server.py | 99 +++++++++++++++++++++++++++++++------------
 llms/mlx_lm/utils.py  | 30 ++++++-------
 3 files changed, 94 insertions(+), 43 deletions(-)

diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index aada5f6c..48364bee 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -17,7 +17,7 @@ mlx_lm.server --model
 For example:
 
 ```shell
-mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
 ```
 
 This will start a text generation server on port `8080` of the `localhost`
@@ -73,4 +73,8 @@ curl localhost:8080/v1/chat/completions \
   applying repetition penalty. Defaults to `20`.
 
 - `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
-  values. Defaults to `None`.
\ No newline at end of file
+  values. Defaults to `None`.
+
+- `logprobs`: (Optional) An integer specifying the number of top tokens and
+  corresponding log probabilities to return for each output in the generated
+  sequence. If set, this can be any value between 1 and 10, inclusive.
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 97a9b40c..19f3f46a 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -6,11 +6,13 @@ import logging
 import time
 import uuid
 import warnings
+from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
+from transformers import PreTrainedTokenizer
 
 from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
@@ -27,18 +29,22 @@ def stopping_criteria(
     eos_token_id: Union[int, None],
 ) -> StopCondition:
     """
-    Determines whether the token generation should stop based on predefined conditions.
+    Determines whether the token generation should stop based on predefined
+    conditions.
 
     Args:
         tokens (List[int]): The current sequence of generated tokens.
-        stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
-            If the end of the `tokens` list matches any of these sequences, the generation should stop.
-        eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
-            the generation should stop.
+        stop_id_sequences (List[List[[int]]): A list of integer lists, each
+            representing a sequence of token IDs. If the end of the `tokens`
+            list matches any of these sequences, the generation should stop.
+        eos_token_id (Union[int, None]): The token ID that represents the
+            end-of-sequence. If the last token in `tokens` matches this, the
+            generation should stop.
 
     Returns:
-        StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
-            and how many tokens should be trimmed from the end if it has (`trim_length`).
+        StopCondition: A named tuple indicating whether the stop condition has
+            been met (`stop_met`) and how many tokens should be trimmed from the
+            end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
         return StopCondition(stop_met=True, trim_length=1)
@@ -53,7 +59,10 @@ def stopping_criteria(
 
 def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     default_role_mapping = {
-        "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
+        "system_prompt": (
+            "A chat between a curious user and an artificial intelligence "
+            "assistant. The assistant follows the given rules no matter what."
+        ),
         "system": "ASSISTANT's RULE: ",
         "user": "USER: ",
         "assistant": "ASSISTANT: ",
@@ -136,7 +145,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
         self.logit_bias = self.body.get("logit_bias", None)
-
+        self.logprobs = self.body.get("logprobs", -1)
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
@@ -184,6 +193,11 @@ class APIHandler(BaseHTTPRequestHandler):
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
 
+        if self.logprobs != -1 and not (0 < self.logprobs <= 10):
+            raise ValueError(
+                f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
+            )
+
         if (
             not isinstance(self.repetition_context_size, int)
             or self.repetition_context_size < 0
@@ -208,24 +222,34 @@ class APIHandler(BaseHTTPRequestHandler):
         finish_reason: Union[Literal["length", "stop"], None],
         prompt_token_count: Optional[int] = None,
         completion_token_count: Optional[int] = None,
+        token_logprobs: Optional[List[float]] = None,
+        top_tokens: Optional[List[Dict[int, float]]] = None,
+        tokens: Optional[List[int]] = None,
     ) -> dict:
         """
-        Generate a single response packet based on response type (stream or not), completion type and parameters.
+        Generate a single response packet based on response type (stream or
+        not), completion type and parameters.
 
         Args:
             text (str): Text generated by model
-            finish_reason (Union[Literal["length", "stop"], None]):
-                The reason the response is being sent: "length", "stop" or None
-            prompt_token_count (Optional[int]):
-                The amount of tokens in the prompt,
-                used to populate the "usage" field (not used when stream)
-            completion_token_count (Optional[int]):
-                The amount of tokens in the response,
-                used to populate the "usage" field (not used when stream)
+            finish_reason (Union[Literal["length", "stop"], None]): The reason the
+                response is being sent: "length", "stop" or `None`.
+            prompt_token_count (Optional[int]): The number of tokens in the prompt,
+                used to populate the "usage" field (not used when stream).
+            completion_token_count (Optional[int]): The number of tokens in the
+                response, used to populate the "usage" field (not used when stream).
+            token_logprobs (Optional[List[float]]): The log probabilities per token,
+                in token order.
+            top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
+                tokens to logprobs for the top N tokens at each token position.
+            tokens (Optional[List[int]]): List of tokens to return with logprobs structure
 
         Returns:
-            dict: A dictionary containing the response, imitating OpenAI's API
+            dict: A dictionary containing the response, in the same format as
+                OpenAI's API.
         """
+        token_logprobs = token_logprobs if token_logprobs else []
+        top_logprobs = top_tokens if top_tokens else []
 
         # Static response
         response = {
@@ -237,7 +261,11 @@ class APIHandler(BaseHTTPRequestHandler):
             "choices": [
                 {
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": {
+                        "token_logprobs": token_logprobs,
+                        "top_logprobs": top_logprobs,
+                        "tokens": tokens,
+                    },
                     "finish_reason": finish_reason,
                 }
             ],
@@ -281,8 +309,8 @@ class APIHandler(BaseHTTPRequestHandler):
 
         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
-            stop_id_sequences (List[List[int]]):
-                A list of stop words passed to the stopping_criteria function
+            stop_id_sequences (List[List[int]]): A list of stop words passed
+                to the stopping_criteria function
         """
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()
@@ -290,7 +318,9 @@ class APIHandler(BaseHTTPRequestHandler):
         finish_reason = "length"
         stop_sequence_suffix = None
         logging.debug(f"Starting completion:")
-        for (token, _), _ in zip(
+        token_logprobs = []
+        top_tokens = []
+        for (token, logprobs), _ in zip(
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -305,6 +335,16 @@ class APIHandler(BaseHTTPRequestHandler):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
+
+            if self.logprobs > 0:
+                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                top_indices = sorted_indices[: self.logprobs]
+                top_logprobs = logprobs[top_indices]
+                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
+                top_tokens.append(dict(top_token_info))
+
+            token_logprobs.append(logprobs[token].item())
+
             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
@@ -322,7 +362,15 @@ class APIHandler(BaseHTTPRequestHandler):
             if stop_sequence_suffix is None
             else detokenizer.text[: -len(stop_sequence_suffix)]
         )
-        response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
+        response = self.generate_response(
+            text,
+            finish_reason,
+            len(prompt),
+            len(tokens),
+            token_logprobs=token_logprobs,
+            top_tokens=top_tokens,
+            tokens=tokens,
+        )
 
         response_json = json.dumps(response).encode()
         indent = "\t"  # Backslashes can't be inside of f-strings
@@ -458,7 +506,6 @@ class APIHandler(BaseHTTPRequestHandler):
 
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-
        prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d7de95bf..de9c877a 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -149,10 +149,11 @@ def generate_step(
            consider for repetition penalty. Default: ``20``.
         top_p (float, optional): Nulceus sampling, higher means model considers
            more less likely words.
+        logit_bias (dictionary, optional): Additive logit bias.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing
-        one token and probability per call.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+        one token and a vector of log probabilities.
""" def sample(logits: mx.array) -> Tuple[mx.array, float]: @@ -160,7 +161,7 @@ def generate_step( indices = mx.array(list(logit_bias.keys())) values = mx.array(list(logit_bias.values())) logits[:, indices] += values - softmax_logits = mx.softmax(logits) + logprobs = logits - mx.logsumexp(logits) if temp == 0: token = mx.argmax(logits, axis=-1) @@ -170,8 +171,7 @@ def generate_step( else: token = mx.random.categorical(logits * (1 / temp)) - prob = softmax_logits[0, token] - return token, prob + return token, logprobs if repetition_penalty and ( repetition_penalty < 0 or not isinstance(repetition_penalty, float) @@ -202,24 +202,24 @@ def generate_step( logits = apply_repetition_penalty( logits, repetition_context, repetition_penalty ) - y, prob = sample(logits) + y, logprobs = sample(logits) repetition_context.append(y.item()) else: - y, prob = sample(logits) + y, logprobs = sample(logits) if repetition_context_size: if len(repetition_context) > repetition_context_size: repetition_context = repetition_context[-repetition_context_size:] - return y, prob + return y, logprobs.squeeze(0) - y, p = _step(y) + y, logprobs = _step(y) mx.async_eval(y) while True: - next_y, next_p = _step(y) + next_y, next_logprobs = _step(y) mx.async_eval(next_y) - yield y.item(), p - y, p = next_y, next_p + yield y.item(), logprobs + y, logprobs = next_y, next_logprobs def stream_generate( @@ -249,7 +249,7 @@ def stream_generate( detokenizer = tokenizer.detokenizer detokenizer.reset() - for (token, prob), n in zip( + for (token, _), n in zip( generate_step(prompt_tokens, model, **kwargs), range(max_tokens), ): @@ -301,7 +301,7 @@ def generate( tic = time.perf_counter() detokenizer.reset() - for (token, prob), n in zip( + for (token, logprobs), n in zip( generate_step(prompt_tokens, model, **kwargs), range(max_tokens), ): @@ -316,7 +316,7 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - formatter(detokenizer.last_segment, prob.item()) + formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) else: print(detokenizer.last_segment, end="", flush=True)