Validate server params & fix logit bias bug (#731)

* Bug fix in logit bias

* Add parameter validations

* Fix typo

* Update docstrings to match MLX styling

* Black style + fix a validation bug
Karim Elmaaroufi authored on 2024-04-30 07:27:40 -07:00 (committed by GitHub)
parent 7c0962f4e2
commit 4bf2eb17f2
3 changed files with 55 additions and 8 deletions

llms/mlx_lm/SERVER.md

@@ -71,3 +71,6 @@ curl localhost:8080/v1/chat/completions \
 - `repetition_context_size`: (Optional) The size of the context window for
   applying repetition penalty. Defaults to `20`.
+
+- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
+  values. Defaults to `None`.
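
For context, a minimal client sketch exercising the newly documented `logit_bias` parameter. This is a hedged example, not part of the change: it assumes the server is running locally on the default port 8080, and the token IDs 1234 and 5678 are arbitrary placeholders.

import requests

# Hypothetical request against a locally running mlx_lm server.
# JSON object keys are always strings, so token IDs are sent as
# strings; the server casts them back to int (see server.py below).
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 50,
        "temperature": 0.7,
        "logit_bias": {"1234": 5.0, "5678": -10.0},
    },
)
print(response.json())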

llms/mlx_lm/server.py

@@ -104,7 +104,7 @@ class APIHandler(BaseHTTPRequestHandler):
     def do_POST(self):
         """
-        Respond to a POST request from a client
+        Respond to a POST request from a client.
         """
         endpoints = {
             "/v1/completions": self.handle_text_completions,
@@ -137,6 +137,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
+        self.logit_bias = self.body.get("logit_bias", None)
+        self.validate_model_parameters()

         # Get stop id sequences, if provided
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
@@ -159,6 +161,46 @@
         method = self.handle_stream if self.stream else self.handle_completion
         method(prompt, stop_id_sequences)

+    def validate_model_parameters(self):
+        """
+        Validate the model parameters passed in the request for the correct types and values.
+        """
+        if not isinstance(self.stream, bool):
+            raise ValueError("stream must be a boolean")
+
+        if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
+            raise ValueError("max_tokens must be a non-negative integer")
+
+        if not isinstance(self.temperature, float) or self.temperature < 0:
+            raise ValueError("temperature must be a non-negative float")
+
+        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+            raise ValueError("top_p must be a float between 0 and 1")
+
+        if (
+            not isinstance(self.repetition_penalty, float)
+            or self.repetition_penalty < 0
+        ):
+            raise ValueError("repetition_penalty must be a non-negative float")
+
+        if (
+            not isinstance(self.repetition_context_size, int)
+            or self.repetition_context_size < 0
+        ):
+            raise ValueError("repetition_context_size must be a non-negative integer")
+
+        if self.logit_bias is not None:
+            if not isinstance(self.logit_bias, dict):
+                raise ValueError("logit_bias must be a dict of int to float")
+
+            try:
+                self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
+            except ValueError:
+                raise ValueError("logit_bias must be a dict of int to float")
+
+        if not isinstance(self.requested_model, str):
+            raise ValueError("model must be a string")
+
     def generate_response(
         self,
         text: str,
@@ -167,8 +209,7 @@ class APIHandler(BaseHTTPRequestHandler):
         completion_token_count: Optional[int] = None,
     ) -> dict:
         """
-        Generate a single response packet based on response type (stream or not),
-        completion type and parameters
+        Generate a single response packet based on response type (stream or not), completion type and parameters.

         Args:
             text (str): Text generated by model
@@ -235,7 +276,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_id_sequences: List[List[int]],
     ):
         """
-        Generate a response to a prompt and send it to the client in a single batch
+        Generate a response to a prompt and send it to the client in a single batch.

         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
@@ -299,7 +340,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_id_sequences: List[List[int]],
     ):
         """
-        Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
+        Generate a response to the prompt and forward it to the client using a Server Sent Events (SSE) stream.

         Args:
             prompt (mx.array): The prompt, in token form inside of a mlx array
@@ -374,7 +415,7 @@ class APIHandler(BaseHTTPRequestHandler):
     def handle_chat_completions(self) -> mx.array:
         """
-        Handle a chat completion request
+        Handle a chat completion request.

         Returns:
             mx.array: A mx.array of the tokenized prompt from the request body
@@ -405,7 +446,7 @@ class APIHandler(BaseHTTPRequestHandler):
     def handle_text_completions(self) -> mx.array:
         """
-        Handle a text completion request
+        Handle a text completion request.

         Returns:
             mx.array: A mx.array of the tokenized prompt from the request body
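
To make the new checks concrete, here is a small standalone sketch that mirrors the validation rules added above (a hypothetical `check` helper for illustration only, not part of the server). It highlights one subtlety of `isinstance`-based checks: a JSON integer such as "temperature": 1 is rejected because Python's isinstance(1, float) is False, while `logit_bias` keys arrive as strings (JSON object keys) and are cast to int.

# Hypothetical standalone mirror of validate_model_parameters, for illustration.
def check(params: dict) -> str:
    temperature = params.get("temperature", 1.0)
    if not isinstance(temperature, float) or temperature < 0:
        return "rejected: temperature must be a non-negative float"

    logit_bias = params.get("logit_bias", None)
    if logit_bias is not None:
        if not isinstance(logit_bias, dict):
            return "rejected: logit_bias must be a dict of int to float"
        try:
            logit_bias = {int(k): v for k, v in logit_bias.items()}
        except ValueError:
            return "rejected: logit_bias must be a dict of int to float"

    return "ok"

print(check({"temperature": 0.7, "logit_bias": {"42": 1.5}}))  # ok
print(check({"temperature": 1}))            # rejected: JSON integer, not a float
print(check({"logit_bias": {"abc": 1.0}}))  # rejected: non-numeric key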

llms/mlx_lm/utils.py

@@ -136,7 +136,10 @@ def generate_step(
     """

     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        logits = logits + logit_bias if logit_bias else logits
+        if logit_bias:
+            indices = mx.array(list(logit_bias.keys()))
+            values = mx.array(list(logit_bias.values()))
+            logits[:, indices] += values
         softmax_logits = mx.softmax(logits)

         if temp == 0:
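
The replaced line was the bug: `logit_bias` is a dict, so `logits + logit_bias` adds a dict to an mx.array, which errors out rather than applying per-token offsets. The fix gathers the biased token IDs and bias values into arrays and adds them only at those vocabulary positions. A minimal sketch of the fixed behavior, with a toy 10-token vocabulary and arbitrary token IDs:

import mlx.core as mx

# Toy logits for a batch of one over a 10-token vocabulary.
logits = mx.zeros((1, 10))
logit_bias = {3: 5.0, 7: -2.0}  # arbitrary token IDs for illustration

# Mirror of the fixed sample() logic: scatter each bias value onto the
# column of its token ID, leaving all other logits untouched.
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values

print(logits)  # columns 3 and 7 now read 5.0 and -2.0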