Server: support stream_options (#913)

* Server: support stream_options

see https://x.com/OpenAIDevs/status/1787573348496773423

* Server: support stream_options

* Server: check None type
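
For context: `stream_options` mirrors the field of the same name in OpenAI's streaming API. When a client sends `"stream": true` together with `"stream_options": {"include_usage": true}`, the server emits one extra chunk, with an empty `choices` list and a populated `usage` object, right before the final `[DONE]` frame. A minimal client sketch, assuming a locally running `mlx_lm.server` on its default port 8080 and the third-party `requests` package (neither is part of this diff):

```python
import json

import requests  # assumption: installed separately; not a server dependency

# Assumption: the server was started with `python -m mlx_lm.server` (port 8080).
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": True,
        # The new field: request a usage chunk at the end of the stream.
        "stream_options": {"include_usage": True},
    },
    stream=True,
)

for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: ") :]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    if chunk.get("usage"):  # the extra chunk: empty "choices", token counts
        print("usage:", chunk["usage"])
```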
Author: madroid
Date: 2024-07-26 23:58:52 +08:00 (committed by GitHub)
Parent: 46da74fea2
Commit: 85dc76f6e0

@@ -6,16 +6,12 @@ import logging
 import time
 import uuid
 import warnings
-from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 from transformers import PreTrainedTokenizer
 
 from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
@@ -195,6 +191,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Extract request parameters from the body
         self.stream = self.body.get("stream", False)
+        self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
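
When the field is absent, `self.stream_options` stays `None`, which is exactly what the guard in the streaming path below checks before indexing into it. A small sketch of the two cases (the request bodies are hypothetical; the lookup and guard mirror the handler code):

```python
# Hypothetical request bodies; the lookup mirrors self.body.get("stream_options", None).
body_without = {"stream": True}
body_with = {"stream": True, "stream_options": {"include_usage": True}}

for body in (body_without, body_with):
    stream_options = body.get("stream_options", None)
    # Same None check the server performs before emitting the usage chunk:
    include_usage = stream_options is not None and stream_options["include_usage"]
    print(include_usage)  # False for body_without, True for body_with
```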
@@ -525,9 +522,33 @@ class APIHandler(BaseHTTPRequestHandler):
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
 
+        if self.stream_options is not None and self.stream_options["include_usage"]:
+            response = self.completion_usage_response(len(prompt), len(tokens))
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+
         self.wfile.write("data: [DONE]\n\n".encode())
         self.wfile.flush()
 
+    def completion_usage_response(
+        self,
+        prompt_token_count: Optional[int] = None,
+        completion_token_count: Optional[int] = None,
+    ):
+        response = {
+            "id": self.request_id,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "object": "chat.completion",
+            "model": self.requested_model,
+            "created": self.created,
+            "choices": [],
+            "usage": {
+                "prompt_tokens": prompt_token_count,
+                "completion_tokens": completion_token_count,
+                "total_tokens": prompt_token_count + completion_token_count,
+            },
+        }
+        return response
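
For a hypothetical request with a 12-token prompt and a 34-token completion, the chunk this method returns, and which the streaming path serializes just before `[DONE]`, would look roughly like this (the `id`, `system_fingerprint`, and `created` values are illustrative):

```python
# Illustrative shape of the usage chunk; all concrete values are examples.
usage_chunk = {
    "id": "chatcmpl-...",            # self.request_id
    "system_fingerprint": "fp_...",  # a fresh uuid4 per response
    "object": "chat.completion",
    "model": "default_model",
    "created": 1722009532,
    "choices": [],                   # empty: this chunk carries only usage
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 34,
        "total_tokens": 46,          # prompt_tokens + completion_tokens
    },
}
```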
     def handle_chat_completions(self) -> mx.array:
         """
         Handle a chat completion request.