From 4bf2eb17f2f68e107219e853e51d921f263aabae Mon Sep 17 00:00:00 2001
From: Karim Elmaaroufi
Date: Tue, 30 Apr 2024 07:27:40 -0700
Subject: [PATCH] Validate server params & fix logit bias bug (#731)

* Bug fix in logit bias

* Add parameter validations

* Fix typo

* Update docstrings to match MLX styling

* Black style + fix a validation bug
---
 llms/mlx_lm/SERVER.md |  3 +++
 llms/mlx_lm/server.py | 55 +++++++++++++++++++++++++++++++++++++------
 llms/mlx_lm/utils.py  |  5 +++-
 3 files changed, 55 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index edea5457..aada5f6c 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -71,3 +71,6 @@ curl localhost:8080/v1/chat/completions \
 
 - `repetition_context_size`: (Optional) The size of the context window for
   applying repetition penalty. Defaults to `20`.
+
+- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
+  values. Defaults to `None`.
\ No newline at end of file
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index ffbe7556..868f7a2f 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -104,7 +104,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
     def do_POST(self):
         """
-        Respond to a POST request from a client
+        Respond to a POST request from a client.
         """
         endpoints = {
             "/v1/completions": self.handle_text_completions,
@@ -137,6 +137,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
         self.logit_bias = self.body.get("logit_bias", None)
 
+        self.validate_model_parameters()
+
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
@@ -159,6 +161,46 @@ class APIHandler(BaseHTTPRequestHandler):
         method = self.handle_stream if self.stream else self.handle_completion
         method(prompt, stop_id_sequences)
 
+    def validate_model_parameters(self):
+        """
+        Validate the model parameters passed in the request for the correct types and values.
+        """
+        if not isinstance(self.stream, bool):
+            raise ValueError("stream must be a boolean")
+
+        if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
+            raise ValueError("max_tokens must be a non-negative integer")
+
+        if not isinstance(self.temperature, float) or self.temperature < 0:
+            raise ValueError("temperature must be a non-negative float")
+
+        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+            raise ValueError("top_p must be a float between 0 and 1")
+
+        if (
+            not isinstance(self.repetition_penalty, float)
+            or self.repetition_penalty < 0
+        ):
+            raise ValueError("repetition_penalty must be a non-negative float")
+
+        if (
+            not isinstance(self.repetition_context_size, int)
+            or self.repetition_context_size < 0
+        ):
+            raise ValueError("repetition_context_size must be a non-negative integer")
+
+        if self.logit_bias is not None:
+            if not isinstance(self.logit_bias, dict):
+                raise ValueError("logit_bias must be a dict of int to float")
+
+            try:
+                self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
+            except ValueError:
+                raise ValueError("logit_bias must be a dict of int to float")
+
+        if not isinstance(self.requested_model, str):
+            raise ValueError("model must be a string")
+
     def generate_response(
         self,
         text: str,
@@ -167,8 +209,7 @@ class APIHandler(BaseHTTPRequestHandler):
         completion_token_count: Optional[int] = None,
     ) -> dict:
         """
-        Generate a single response packet based on response type (stream or not),
-        completion type and parameters
+        Generate a single response packet based on response type (stream or not), completion type and parameters.
 
         Args:
             text (str): Text generated by model
@@ -235,7 +276,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_id_sequences: List[List[int]],
     ):
         """
-        Generate a response to a prompt and send it to the client in a single batch
+        Generate a response to a prompt and send it to the client in a single batch.
 
         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
@@ -299,7 +340,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_id_sequences: List[List[int]],
     ):
         """
-        Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
+        Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream.
 
         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
@@ -374,7 +415,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
     def handle_chat_completions(self) -> mx.array:
         """
-        Handle a chat completion request
+        Handle a chat completion request.
 
         Returns:
             mx.array: A mx.array of the tokenized prompt from the request body
@@ -405,7 +446,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
     def handle_text_completions(self) -> mx.array:
         """
-        Handle a text completion request
+        Handle a text completion request.
 
         Returns:
             mx.array: A mx.array of the tokenized prompt from the request body
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 8516e38c..67b3dfce 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -136,7 +136,10 @@ def generate_step(
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        logits = logits + logit_bias if logit_bias else logits
+        if logit_bias:
+            indices = mx.array(list(logit_bias.keys()))
+            values = mx.array(list(logit_bias.values()))
+            logits[:, indices] += values
         softmax_logits = mx.softmax(logits)
 
         if temp == 0: