mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
chore: add /v1/completions for server (#489)
This commit is contained in:
parent
e5dfef5d9a
commit
19a21bfce4
@ -4,7 +4,7 @@ import time
|
||||
import uuid
|
||||
from collections import namedtuple
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -46,7 +46,7 @@ def stopping_criteria(
|
||||
and how many tokens should be trimmed from the end if it has (`trim_length`).
|
||||
"""
|
||||
if tokens and tokens[-1] == eos_token_id:
|
||||
return StopCondition(stop_met=True, trim_length=0)
|
||||
return StopCondition(stop_met=True, trim_length=1)
|
||||
|
||||
for stop_ids in stop_id_sequences:
|
||||
if len(tokens) >= len(stop_ids):
|
||||
@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
def create_response(chat_id, requested_model, prompt, tokens, text):
|
||||
def create_chat_response(chat_id, requested_model, prompt, tokens, text):
|
||||
response = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion",
|
||||
@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
|
||||
return response
|
||||
|
||||
|
||||
def create_chunk_response(chat_id, requested_model, next_chunk):
|
||||
def create_completion_response(completion_id, requested_model, prompt, tokens, text):
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": requested_model,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
"choices": [
|
||||
{"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": len(prompt),
|
||||
"completion_tokens": len(tokens),
|
||||
"total_tokens": len(prompt) + len(tokens),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_chat_chunk_response(chat_id, requested_model, next_chunk):
|
||||
response = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion.chunk",
|
||||
@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
|
||||
return response
|
||||
|
||||
|
||||
def create_completion_chunk_response(completion_id, requested_model, next_chunk):
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"choices": [
|
||||
{"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
|
||||
],
|
||||
"model": requested_model,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
}
|
||||
|
||||
|
||||
class APIHandler(BaseHTTPRequestHandler):
|
||||
def _set_headers(self, status_code=200):
|
||||
self.send_response(status_code)
|
||||
@ -186,14 +217,128 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
post_data = self.rfile.read(content_length)
|
||||
self._set_headers(200)
|
||||
|
||||
response = self.handle_post_request(post_data)
|
||||
response = self.handle_chat_completions(post_data)
|
||||
|
||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||
elif self.path == "/v1/completions":
|
||||
content_length = int(self.headers["Content-Length"])
|
||||
post_data = self.rfile.read(content_length)
|
||||
self._set_headers(200)
|
||||
|
||||
response = self.handle_completions(post_data)
|
||||
|
||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||
else:
|
||||
self._set_headers(404)
|
||||
self.wfile.write(b"Not Found")
|
||||
|
||||
def handle_post_request(self, post_data):
|
||||
def generate_response(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
response_id: str,
|
||||
requested_model: str,
|
||||
stop_id_sequences: List[np.ndarray],
|
||||
eos_token_id: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
|
||||
):
|
||||
tokens = []
|
||||
for token, _ in zip(
|
||||
generate(
|
||||
prompt,
|
||||
_model,
|
||||
temperature,
|
||||
top_p=top_p,
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
tokens.append(token)
|
||||
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
|
||||
if stop_condition.stop_met:
|
||||
if stop_condition.trim_length:
|
||||
tokens = tokens[: -stop_condition.trim_length]
|
||||
break
|
||||
|
||||
text = _tokenizer.decode(tokens)
|
||||
return response_creator(response_id, requested_model, prompt, tokens, text)
|
||||
|
||||
def hanlde_stream(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
response_id: str,
|
||||
requested_model: str,
|
||||
stop_id_sequences: List[np.ndarray],
|
||||
eos_token_id: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
response_creator: Callable[[str, str, str], dict],
|
||||
):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self.end_headers()
|
||||
max_stop_id_sequence_len = (
|
||||
max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
|
||||
)
|
||||
tokens = []
|
||||
current_generated_text_index = 0
|
||||
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
|
||||
stop_sequence_buffer = []
|
||||
REPLACEMENT_CHAR = "\ufffd"
|
||||
for token, _ in zip(
|
||||
generate(
|
||||
prompt,
|
||||
_model,
|
||||
temperature,
|
||||
top_p=top_p,
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
tokens.append(token)
|
||||
stop_sequence_buffer.append(token)
|
||||
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
|
||||
if REPLACEMENT_CHAR in _tokenizer.decode(token):
|
||||
continue
|
||||
stop_condition = stopping_criteria(
|
||||
tokens,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
)
|
||||
if stop_condition.stop_met:
|
||||
if stop_condition.trim_length:
|
||||
tokens = tokens[: -stop_condition.trim_length]
|
||||
break
|
||||
# This is a workaround because the llama tokenizer emits spaces when decoding token by token.
|
||||
generated_text = _tokenizer.decode(tokens)
|
||||
next_chunk = generated_text[current_generated_text_index:]
|
||||
current_generated_text_index = len(generated_text)
|
||||
|
||||
response = response_creator(response_id, requested_model, next_chunk)
|
||||
try:
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
stop_sequence_buffer = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
break
|
||||
# check is there any remaining text to send
|
||||
if stop_sequence_buffer:
|
||||
generated_text = _tokenizer.decode(tokens)
|
||||
next_chunk = generated_text[current_generated_text_index:]
|
||||
response = response_creator(response_id, requested_model, next_chunk)
|
||||
try:
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
self.wfile.write(f"data: [DONE]\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
def handle_chat_completions(self, post_data):
|
||||
body = json.loads(post_data.decode("utf-8"))
|
||||
chat_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
|
||||
@ -223,91 +368,74 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
temperature = body.get("temperature", 1.0)
|
||||
top_p = body.get("top_p", 1.0)
|
||||
if not stream:
|
||||
tokens = []
|
||||
for token, _ in zip(
|
||||
generate(
|
||||
prompt,
|
||||
_model,
|
||||
temperature,
|
||||
top_p=top_p,
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
tokens.append(token)
|
||||
stop_condition = stopping_criteria(
|
||||
tokens, stop_id_sequences, eos_token_id
|
||||
)
|
||||
if stop_condition.stop_met:
|
||||
if stop_condition.trim_length:
|
||||
tokens = tokens[: -stop_condition.trim_length]
|
||||
break
|
||||
|
||||
text = _tokenizer.decode(tokens)
|
||||
return create_response(chat_id, requested_model, prompt, tokens, text)
|
||||
else:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self.end_headers()
|
||||
max_stop_id_sequence_len = (
|
||||
max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
|
||||
return self.generate_response(
|
||||
prompt,
|
||||
chat_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_chat_response,
|
||||
)
|
||||
else:
|
||||
self.hanlde_stream(
|
||||
prompt,
|
||||
chat_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_chat_chunk_response,
|
||||
)
|
||||
tokens = []
|
||||
current_generated_text_index = 0
|
||||
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
|
||||
stop_sequence_buffer = []
|
||||
REPLACEMENT_CHAR = "\ufffd"
|
||||
for token, _ in zip(
|
||||
generate(
|
||||
prompt,
|
||||
_model,
|
||||
temperature,
|
||||
top_p=top_p,
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
tokens.append(token)
|
||||
stop_sequence_buffer.append(token)
|
||||
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
|
||||
if REPLACEMENT_CHAR in _tokenizer.decode(token):
|
||||
continue
|
||||
stop_condition = stopping_criteria(
|
||||
tokens,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
)
|
||||
if stop_condition.stop_met:
|
||||
if stop_condition.trim_length:
|
||||
tokens = tokens[: -stop_condition.trim_length]
|
||||
break
|
||||
# This is a workaround because the llama tokenizer emits spaces when decoding token by token.
|
||||
generated_text = _tokenizer.decode(tokens)
|
||||
next_chunk = generated_text[current_generated_text_index:]
|
||||
current_generated_text_index = len(generated_text)
|
||||
|
||||
response = create_chunk_response(
|
||||
chat_id, requested_model, next_chunk
|
||||
)
|
||||
try:
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
stop_sequence_buffer = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
break
|
||||
# check is there any remaining text to send
|
||||
if stop_sequence_buffer:
|
||||
generated_text = _tokenizer.decode(tokens)
|
||||
next_chunk = generated_text[current_generated_text_index:]
|
||||
response = create_chunk_response(chat_id, requested_model, next_chunk)
|
||||
try:
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
self.wfile.write(f"data: [DONE]\n\n".encode())
|
||||
self.wfile.flush()
|
||||
def handle_completions(self, post_data):
|
||||
body = json.loads(post_data.decode("utf-8"))
|
||||
completion_id = f"cmpl-{uuid.uuid4()}"
|
||||
prompt_text = body["prompt"]
|
||||
prompt = _tokenizer.encode(prompt_text, return_tensors="np")
|
||||
prompt = mx.array(prompt[0])
|
||||
stop_words = body.get("stop", [])
|
||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||
stop_id_sequences = [
|
||||
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
|
||||
0
|
||||
]
|
||||
for stop_word in stop_words
|
||||
]
|
||||
eos_token_id = _tokenizer.eos_token_id
|
||||
max_tokens = body.get("max_tokens", 100)
|
||||
stream = body.get("stream", False)
|
||||
requested_model = body.get("model", "default_model")
|
||||
temperature = body.get("temperature", 1.0)
|
||||
top_p = body.get("top_p", 1.0)
|
||||
if not stream:
|
||||
return self.generate_response(
|
||||
prompt,
|
||||
completion_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_completion_response,
|
||||
)
|
||||
else:
|
||||
self.hanlde_stream(
|
||||
prompt,
|
||||
completion_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_completion_chunk_response,
|
||||
)
|
||||
|
||||
|
||||
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
|
||||
|
Loading…
Reference in New Issue
Block a user