diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 23d327e5..c13878f3 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -6,16 +6,12 @@
 import logging
 import time
 import uuid
 import warnings
-from functools import lru_cache
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
-import mlx.nn as nn
-from transformers import PreTrainedTokenizer
 
-from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
 
@@ -195,6 +191,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # Extract request parameters from the body
         self.stream = self.body.get("stream", False)
+        self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
@@ -525,9 +522,33 @@ class APIHandler(BaseHTTPRequestHandler):
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
 
+        if self.stream_options is not None and self.stream_options["include_usage"]:
+            response = self.completion_usage_response(len(prompt), len(tokens))
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+
         self.wfile.write("data: [DONE]\n\n".encode())
         self.wfile.flush()
 
+    def completion_usage_response(
+        self,
+        prompt_token_count: Optional[int] = None,
+        completion_token_count: Optional[int] = None,
+    ):
+        response = {
+            "id": self.request_id,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "object": "chat.completion",
+            "model": self.requested_model,
+            "created": self.created,
+            "choices": [],
+            "usage": {
+                "prompt_tokens": prompt_token_count,
+                "completion_tokens": completion_token_count,
+                "total_tokens": prompt_token_count + completion_token_count,
+            },
+        }
+        return response
+
     def handle_chat_completions(self) -> mx.array:
         """
         Handle a chat completion request.
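
For reference, below is a minimal client-side sketch of how the new `stream_options` request parameter could be exercised. It is an illustration under assumptions, not part of the change: the host/port, model name, and prompt are placeholder values (the server's default port is assumed to be 8080), and the `requests` library stands in for whatever HTTP client is actually used. When `include_usage` is true, the server emits one extra chunk with an empty "choices" list and the token counts from completion_usage_response just before the final "data: [DONE]" event.

# Hypothetical client example; assumes a local `mlx_lm.server` instance
# listening on the default port and the OpenAI-style chat completions path
# that server.py handles.
import json

import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)

for line in resp.iter_lines():
    # Skip SSE keep-alive blanks and the terminating [DONE] event.
    if not line or line == b"data: [DONE]":
        continue
    chunk = json.loads(line.removeprefix(b"data: "))
    if chunk.get("usage"):
        # The usage chunk carries prompt/completion/total token counts.
        print(chunk["usage"])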