mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Validate server params & fix logit bias bug (#731)
* Bug fix in logit bias * Add parameter validations * Fix typo * Update docstrings to match MLX styling * Black style + fix a validation bug
This commit is contained in:
@@ -71,3 +71,6 @@ curl localhost:8080/v1/chat/completions \
|
|||||||
|
|
||||||
- `repetition_context_size`: (Optional) The size of the context window for
|
- `repetition_context_size`: (Optional) The size of the context window for
|
||||||
applying repetition penalty. Defaults to `20`.
|
applying repetition penalty. Defaults to `20`.
|
||||||
|
|
||||||
|
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
|
||||||
|
values. Defaults to `None`.
|
@@ -104,7 +104,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
"""
|
"""
|
||||||
Respond to a POST request from a client
|
Respond to a POST request from a client.
|
||||||
"""
|
"""
|
||||||
endpoints = {
|
endpoints = {
|
||||||
"/v1/completions": self.handle_text_completions,
|
"/v1/completions": self.handle_text_completions,
|
||||||
@@ -137,6 +137,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||||
self.logit_bias = self.body.get("logit_bias", None)
|
self.logit_bias = self.body.get("logit_bias", None)
|
||||||
|
|
||||||
|
self.validate_model_parameters()
|
||||||
|
|
||||||
# Get stop id sequences, if provided
|
# Get stop id sequences, if provided
|
||||||
stop_words = self.body.get("stop", [])
|
stop_words = self.body.get("stop", [])
|
||||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||||
@@ -159,6 +161,46 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
method = self.handle_stream if self.stream else self.handle_completion
|
method = self.handle_stream if self.stream else self.handle_completion
|
||||||
method(prompt, stop_id_sequences)
|
method(prompt, stop_id_sequences)
|
||||||
|
|
||||||
|
def validate_model_parameters(self):
|
||||||
|
"""
|
||||||
|
Validate the model parameters passed in the request for the correct types and values.
|
||||||
|
"""
|
||||||
|
if not isinstance(self.stream, bool):
|
||||||
|
raise ValueError("stream must be a boolean")
|
||||||
|
|
||||||
|
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
|
||||||
|
raise ValueError("max_tokens must be a non-negative integer")
|
||||||
|
|
||||||
|
if not isinstance(self.temperature, float) or self.temperature < 0:
|
||||||
|
raise ValueError("temperature must be a non-negative float")
|
||||||
|
|
||||||
|
if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
|
||||||
|
raise ValueError("top_p must be a float between 0 and 1")
|
||||||
|
|
||||||
|
if (
|
||||||
|
not isinstance(self.repetition_penalty, float)
|
||||||
|
or self.repetition_penalty < 0
|
||||||
|
):
|
||||||
|
raise ValueError("repetition_penalty must be a non-negative float")
|
||||||
|
|
||||||
|
if (
|
||||||
|
not isinstance(self.repetition_context_size, int)
|
||||||
|
or self.repetition_context_size < 0
|
||||||
|
):
|
||||||
|
raise ValueError("repetition_context_size must be a non-negative integer")
|
||||||
|
|
||||||
|
if self.logit_bias is not None:
|
||||||
|
if not isinstance(self.logit_bias, dict):
|
||||||
|
raise ValueError("logit_bias must be a dict of int to float")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("logit_bias must be a dict of int to float")
|
||||||
|
|
||||||
|
if not isinstance(self.requested_model, str):
|
||||||
|
raise ValueError("model must be a string")
|
||||||
|
|
||||||
def generate_response(
|
def generate_response(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -167,8 +209,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
completion_token_count: Optional[int] = None,
|
completion_token_count: Optional[int] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Generate a single response packet based on response type (stream or not),
|
Generate a single response packet based on response type (stream or not), completion type and parameters.
|
||||||
completion type and parameters
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str): Text generated by model
|
text (str): Text generated by model
|
||||||
@@ -235,7 +276,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_id_sequences: List[List[int]],
|
stop_id_sequences: List[List[int]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate a response to a prompt and send it to the client in a single batch
|
Generate a response to a prompt and send it to the client in a single batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||||
@@ -299,7 +340,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_id_sequences: List[List[int]],
|
stop_id_sequences: List[List[int]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream
|
Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||||
@@ -374,7 +415,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def handle_chat_completions(self) -> mx.array:
|
def handle_chat_completions(self) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Handle a chat completion request
|
Handle a chat completion request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mx.array: A mx.array of the tokenized prompt from the request body
|
mx.array: A mx.array of the tokenized prompt from the request body
|
||||||
@@ -405,7 +446,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def handle_text_completions(self) -> mx.array:
|
def handle_text_completions(self) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Handle a text completion request
|
Handle a text completion request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mx.array: A mx.array of the tokenized prompt from the request body
|
mx.array: A mx.array of the tokenized prompt from the request body
|
||||||
|
@@ -136,7 +136,10 @@ def generate_step(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||||
logits = logits + logit_bias if logit_bias else logits
|
if logit_bias:
|
||||||
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
|
values = mx.array(list(logit_bias.values()))
|
||||||
|
logits[:, indices] += values
|
||||||
softmax_logits = mx.softmax(logits)
|
softmax_logits = mx.softmax(logits)
|
||||||
|
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
|
Reference in New Issue
Block a user