Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Server: support stream_options (#913)
* Server: support stream_options; see https://x.com/OpenAIDevs/status/1787573348496773423
* Server: support stream_options
* Server: check None type
parent 46da74fea2
commit 85dc76f6e0
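For context: per the OpenAI announcement linked above, a streaming request that sets `"stream_options": {"include_usage": true}` receives one extra chunk carrying token usage before the final `[DONE]` event. A minimal request body might look like the sketch below; the field names follow the handler in the diff, and the values are illustrative placeholders.

```python
# Illustrative request body for the mlx_lm server's chat completions
# endpoint; field names follow the handler in the diff, values are placeholders.
payload = {
    "model": "default_model",  # handler default when omitted
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 100,         # handler default when omitted
    "stream": True,
    "stream_options": {"include_usage": True},  # the option added by this commit
}
```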
```diff
@@ -6,16 +6,12 @@ import logging
 import time
 import uuid
 import warnings
 from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 from transformers import PreTrainedTokenizer
 
 from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
```
```diff
@@ -195,6 +191,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # Extract request parameters from the body
         self.stream = self.body.get("stream", False)
+        self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
```
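A quick sketch of how those `body.get` fallbacks resolve (my reading of the handler, not code from the repo): `stream_options` stays `None` unless the client sends it, which is exactly what the `is not None` guard in the next hunk checks before emitting a usage chunk.

```python
# Hypothetical request bodies, showing the defaults the handler falls back to.
body = {"stream": True}
body.get("stream_options", None)  # -> None: no usage chunk is emitted

body = {"stream": True, "stream_options": {"include_usage": True}}
body.get("stream_options", None)  # -> {'include_usage': True}: usage chunk is emitted
```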
```diff
@@ -525,9 +522,33 @@ class APIHandler(BaseHTTPRequestHandler):
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
 
+        if self.stream_options is not None and self.stream_options["include_usage"]:
+            response = self.completion_usage_response(len(prompt), len(tokens))
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+
         self.wfile.write("data: [DONE]\n\n".encode())
         self.wfile.flush()
 
+    def completion_usage_response(
+        self,
+        prompt_token_count: Optional[int] = None,
+        completion_token_count: Optional[int] = None,
+    ):
+        response = {
+            "id": self.request_id,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "object": "chat.completion",
+            "model": self.requested_model,
+            "created": self.created,
+            "choices": [],
+            "usage": {
+                "prompt_tokens": prompt_token_count,
+                "completion_tokens": completion_token_count,
+                "total_tokens": prompt_token_count + completion_token_count,
+            },
+        }
+        return response
+
     def handle_chat_completions(self) -> mx.array:
         """
         Handle a chat completion request.
```
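End to end, a client that opts in now sees the usage chunk arrive just before `[DONE]`. Below is a minimal standard-library sketch of such a client, not code from the repo; the URL assumes the server's default host and port (`python -m mlx_lm.server`), so adjust it for your setup.

```python
import json
import urllib.request

# Hedged sketch of a streaming client for the mlx_lm server; the address
# assumes the server defaults, adjust as needed.
req = urllib.request.Request(
    "http://127.0.0.1:8080/v1/chat/completions",
    data=json.dumps(
        {
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
            "stream_options": {"include_usage": True},
        }
    ).encode(),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as response:
    for raw_line in response:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue  # skip blank SSE separator lines
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        if "usage" in chunk:  # the extra chunk this commit adds
            print(chunk["usage"])
```

One detail worth noting: `self.stream_options["include_usage"]` assumes the key is present whenever `stream_options` is sent, so a body like `{"stream_options": {}}` would raise a `KeyError`; `self.stream_options.get("include_usage", False)` would be the defensive variant.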