mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Validate server params & fix logit bias bug (#731)
* Bug fix in logit bias * Add parameter validations * Fix typo * Update docstrings to match MLX styling * Black style + fix a validation bug
This commit is contained in:
parent
7c0962f4e2
commit
4bf2eb17f2
@ -71,3 +71,6 @@ curl localhost:8080/v1/chat/completions \
|
||||
|
||||
- `repetition_context_size`: (Optional) The size of the context window for
|
||||
applying repetition penalty. Defaults to `20`.
|
||||
|
||||
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
|
||||
values. Defaults to `None`.
|
@ -104,7 +104,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def do_POST(self):
|
||||
"""
|
||||
Respond to a POST request from a client
|
||||
Respond to a POST request from a client.
|
||||
"""
|
||||
endpoints = {
|
||||
"/v1/completions": self.handle_text_completions,
|
||||
@ -137,6 +137,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||
self.logit_bias = self.body.get("logit_bias", None)
|
||||
|
||||
self.validate_model_parameters()
|
||||
|
||||
# Get stop id sequences, if provided
|
||||
stop_words = self.body.get("stop", [])
|
||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||
@ -159,6 +161,46 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
method = self.handle_stream if self.stream else self.handle_completion
|
||||
method(prompt, stop_id_sequences)
|
||||
|
||||
def validate_model_parameters(self):
|
||||
"""
|
||||
Validate the model parameters passed in the request for the correct types and values.
|
||||
"""
|
||||
if not isinstance(self.stream, bool):
|
||||
raise ValueError("stream must be a boolean")
|
||||
|
||||
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
|
||||
raise ValueError("max_tokens must be a non-negative integer")
|
||||
|
||||
if not isinstance(self.temperature, float) or self.temperature < 0:
|
||||
raise ValueError("temperature must be a non-negative float")
|
||||
|
||||
if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
|
||||
raise ValueError("top_p must be a float between 0 and 1")
|
||||
|
||||
if (
|
||||
not isinstance(self.repetition_penalty, float)
|
||||
or self.repetition_penalty < 0
|
||||
):
|
||||
raise ValueError("repetition_penalty must be a non-negative float")
|
||||
|
||||
if (
|
||||
not isinstance(self.repetition_context_size, int)
|
||||
or self.repetition_context_size < 0
|
||||
):
|
||||
raise ValueError("repetition_context_size must be a non-negative integer")
|
||||
|
||||
if self.logit_bias is not None:
|
||||
if not isinstance(self.logit_bias, dict):
|
||||
raise ValueError("logit_bias must be a dict of int to float")
|
||||
|
||||
try:
|
||||
self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
|
||||
except ValueError:
|
||||
raise ValueError("logit_bias must be a dict of int to float")
|
||||
|
||||
if not isinstance(self.requested_model, str):
|
||||
raise ValueError("model must be a string")
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
text: str,
|
||||
@ -167,8 +209,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
completion_token_count: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a single response packet based on response type (stream or not),
|
||||
completion type and parameters
|
||||
Generate a single response packet based on response type (stream or not), completion type and parameters.
|
||||
|
||||
Args:
|
||||
text (str): Text generated by model
|
||||
@ -235,7 +276,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Generate a response to a prompt and send it to the client in a single batch
|
||||
Generate a response to a prompt and send it to the client in a single batch.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||
@ -299,7 +340,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
|
||||
Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||
@ -374,7 +415,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def handle_chat_completions(self) -> mx.array:
|
||||
"""
|
||||
Handle a chat completion request
|
||||
Handle a chat completion request.
|
||||
|
||||
Returns:
|
||||
mx.array: A mx.array of the tokenized prompt from the request body
|
||||
@ -405,7 +446,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def handle_text_completions(self) -> mx.array:
|
||||
"""
|
||||
Handle a text completion request
|
||||
Handle a text completion request.
|
||||
|
||||
Returns:
|
||||
mx.array: A mx.array of the tokenized prompt from the request body
|
||||
|
@ -136,7 +136,10 @@ def generate_step(
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
logits = logits + logit_bias if logit_bias else logits
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
logits[:, indices] += values
|
||||
softmax_logits = mx.softmax(logits)
|
||||
|
||||
if temp == 0:
|
||||
|
Loading…
Reference in New Issue
Block a user