Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Server: support stream_options (#913)
* Server: support stream_options, see https://x.com/OpenAIDevs/status/1787573348496773423
* Server: support stream_options
* Server: check None type
This commit is contained in:
parent
46da74fea2
commit
85dc76f6e0
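As context for the change: with stream=true, OpenAI-compatible servers can also accept stream_options={"include_usage": true}, in which case one extra chunk carrying token usage is emitted before the final [DONE] marker. A minimal client-side sketch follows; it assumes a locally running mlx_lm server, and the host, port, and model name are illustrative rather than taken from this commit:

import json

import requests  # assumed available; any HTTP client works

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed local mlx_lm.server address
    json={
        "model": "default_model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
        "stream_options": {"include_usage": True},  # the option this commit adds support for
    },
    stream=True,
)

for line in resp.iter_lines():
    # Server-sent events arrive as lines of the form "data: {...}"
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    if chunk.get("usage"):  # the extra usage chunk has an empty "choices" list
        print(chunk["usage"])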
@@ -6,16 +6,12 @@ import logging
 import time
 import uuid
 import warnings
-from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
-import mlx.nn as nn
-from transformers import PreTrainedTokenizer
 
-from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
 
 
@@ -195,6 +191,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # Extract request parameters from the body
         self.stream = self.body.get("stream", False)
+        self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
@@ -525,9 +522,33 @@ class APIHandler(BaseHTTPRequestHandler):
         self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
         self.wfile.flush()
 
+        if self.stream_options is not None and self.stream_options["include_usage"]:
+            response = self.completion_usage_response(len(prompt), len(tokens))
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+
         self.wfile.write("data: [DONE]\n\n".encode())
         self.wfile.flush()
 
+    def completion_usage_response(
+        self,
+        prompt_token_count: Optional[int] = None,
+        completion_token_count: Optional[int] = None,
+    ):
+        response = {
+            "id": self.request_id,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "object": "chat.completion",
+            "model": self.requested_model,
+            "created": self.created,
+            "choices": [],
+            "usage": {
+                "prompt_tokens": prompt_token_count,
+                "completion_tokens": completion_token_count,
+                "total_tokens": prompt_token_count + completion_token_count,
+            },
+        }
+        return response
+
     def handle_chat_completions(self) -> mx.array:
         """
         Handle a chat completion request.
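With include_usage set, the extra chunk written just before the data: [DONE] line would look roughly like this (field values illustrative; note the empty choices list, which distinguishes it from ordinary content chunks):

data: {"id": "chatcmpl-...", "system_fingerprint": "fp_...", "object": "chat.completion", "model": "default_model", "created": 1718000000, "choices": [], "usage": {"prompt_tokens": 12, "completion_tokens": 100, "total_tokens": 112}}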