chore: add /v1/completions for server (#489)

Anchen 2024-02-27 15:59:33 +11:00 committed by GitHub
parent e5dfef5d9a
commit 19a21bfce4
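Summary: the request plumbing is refactored so prompt handling, generation, and streaming are shared between endpoints, and a new OpenAI-style /v1/completions endpoint is exposed alongside /v1/chat/completions. A rough sketch of exercising the new endpoint once the server is running (the host, port, and model name below are assumptions, not part of this commit):

    import requests  # hypothetical client; any HTTP client works

    resp = requests.post(
        "http://localhost:8080/v1/completions",  # adjust to the server's actual host/port
        json={
            "prompt": "Once upon a time",
            "max_tokens": 50,   # defaults to 100 server-side if omitted
            "temperature": 0.7,
            "stop": ["\n\n"],   # optional: a string or a list of stop sequences
        },
    )
    print(resp.json()["choices"][0]["text"])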


@@ -4,7 +4,7 @@ import time
 import uuid
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
     and how many tokens should be trimmed from the end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=0)
+        return StopCondition(stop_met=True, trim_length=1)
 
     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
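The stopping_criteria change above is a small bug fix: with trim_length=1, the EOS token itself is now trimmed from the generated tokens before decoding, instead of being left in the output. A toy illustration of the trimming (stand-in values, not the server's code):

    tokens = [5, 17, 42, 2]         # hypothetical ids; 2 plays the role of eos_token_id
    trim_length = 1                 # previously 0, which left the EOS id in place
    tokens = tokens[:-trim_length]  # -> [5, 17, 42]; EOS is dropped before decoding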
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
-def create_response(chat_id, requested_model, prompt, tokens, text):
+def create_chat_response(chat_id, requested_model, prompt, tokens, text):
     response = {
         "id": chat_id,
         "object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
     return response
 
 
-def create_chunk_response(chat_id, requested_model, next_chunk):
+def create_completion_response(completion_id, requested_model, prompt, tokens, text):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+        "choices": [
+            {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {
+            "prompt_tokens": len(prompt),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(prompt) + len(tokens),
+        },
+    }
+
+
+def create_chat_chunk_response(chat_id, requested_model, next_chunk):
     response = {
         "id": chat_id,
         "object": "chat.completion.chunk",
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
     return response
 
 
+def create_completion_chunk_response(completion_id, requested_model, next_chunk):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "choices": [
+            {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
+        ],
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+    }
+
+
 class APIHandler(BaseHTTPRequestHandler):
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
@@ -186,43 +217,33 @@ class APIHandler(BaseHTTPRequestHandler):
             post_data = self.rfile.read(content_length)
             self._set_headers(200)
-            response = self.handle_post_request(post_data)
+            response = self.handle_chat_completions(post_data)
+            self.wfile.write(json.dumps(response).encode("utf-8"))
+        elif self.path == "/v1/completions":
+            content_length = int(self.headers["Content-Length"])
+            post_data = self.rfile.read(content_length)
+            self._set_headers(200)
+            response = self.handle_completions(post_data)
             self.wfile.write(json.dumps(response).encode("utf-8"))
         else:
             self._set_headers(404)
             self.wfile.write(b"Not Found")
 
-    def handle_post_request(self, post_data):
-        body = json.loads(post_data.decode("utf-8"))
-        chat_id = f"chatcmpl-{uuid.uuid4()}"
-        if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
-            prompt = _tokenizer.apply_chat_template(
-                body["messages"],
-                tokenize=True,
-                add_generation_prompt=True,
-                return_tensors="np",
-            )
-        else:
-            prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = _tokenizer.encode(prompt, return_tensors="np")
-
-        prompt = mx.array(prompt[0])
-        stop_words = body.get("stop", [])
-        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
-        stop_id_sequences = [
-            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
-                0
-            ]
-            for stop_word in stop_words
-        ]
-        eos_token_id = _tokenizer.eos_token_id
-        max_tokens = body.get("max_tokens", 100)
-        stream = body.get("stream", False)
-        requested_model = body.get("model", "default_model")
-        temperature = body.get("temperature", 1.0)
-        top_p = body.get("top_p", 1.0)
-        if not stream:
+    def generate_response(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, mx.array, List[int], str], dict],
+    ):
         tokens = []
         for token, _ in zip(
             generate(
@@ -234,17 +255,27 @@ class APIHandler(BaseHTTPRequestHandler):
             range(max_tokens),
         ):
             tokens.append(token)
-            stop_condition = stopping_criteria(
-                tokens, stop_id_sequences, eos_token_id
-            )
+            stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
                 break
 
         text = _tokenizer.decode(tokens)
-            return create_response(chat_id, requested_model, prompt, tokens, text)
-        else:
+        return response_creator(response_id, requested_model, prompt, tokens, text)
+
+    def handle_stream(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, str], dict],
+    ):
         self.send_response(200)
         self.send_header("Content-type", "text/event-stream")
         self.send_header("Cache-Control", "no-cache")
@@ -285,9 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
                     next_chunk = generated_text[current_generated_text_index:]
                     current_generated_text_index = len(generated_text)
 
-                    response = create_chunk_response(
-                        chat_id, requested_model, next_chunk
-                    )
+                    response = response_creator(response_id, requested_model, next_chunk)
                     try:
                         self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
                         self.wfile.flush()
@@ -299,7 +328,7 @@ class APIHandler(BaseHTTPRequestHandler):
         if stop_sequence_buffer:
             generated_text = _tokenizer.decode(tokens)
             next_chunk = generated_text[current_generated_text_index:]
-            response = create_chunk_response(chat_id, requested_model, next_chunk)
+            response = response_creator(response_id, requested_model, next_chunk)
             try:
                 self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
                 self.wfile.flush()
@@ -309,6 +338,105 @@ class APIHandler(BaseHTTPRequestHandler):
         self.wfile.write(f"data: [DONE]\n\n".encode())
         self.wfile.flush()
 
+    def handle_chat_completions(self, post_data):
+        body = json.loads(post_data.decode("utf-8"))
+        chat_id = f"chatcmpl-{uuid.uuid4()}"
+        if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
+            prompt = _tokenizer.apply_chat_template(
+                body["messages"],
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="np",
+            )
+        else:
+            prompt = convert_chat(body["messages"], body.get("role_mapping"))
+            prompt = _tokenizer.encode(prompt, return_tensors="np")
+
+        prompt = mx.array(prompt[0])
+        stop_words = body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+                0
+            ]
+            for stop_word in stop_words
+        ]
+        eos_token_id = _tokenizer.eos_token_id
+        max_tokens = body.get("max_tokens", 100)
+        stream = body.get("stream", False)
+        requested_model = body.get("model", "default_model")
+        temperature = body.get("temperature", 1.0)
+        top_p = body.get("top_p", 1.0)
+        if not stream:
+            return self.generate_response(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_response,
+            )
+        else:
+            self.handle_stream(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_chunk_response,
+            )
+
+    def handle_completions(self, post_data):
+        body = json.loads(post_data.decode("utf-8"))
+        completion_id = f"cmpl-{uuid.uuid4()}"
+        prompt_text = body["prompt"]
+        prompt = _tokenizer.encode(prompt_text, return_tensors="np")
+        prompt = mx.array(prompt[0])
+        stop_words = body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+                0
+            ]
+            for stop_word in stop_words
+        ]
+        eos_token_id = _tokenizer.eos_token_id
+        max_tokens = body.get("max_tokens", 100)
+        stream = body.get("stream", False)
+        requested_model = body.get("model", "default_model")
+        temperature = body.get("temperature", 1.0)
+        top_p = body.get("top_p", 1.0)
+        if not stream:
+            return self.generate_response(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_response,
+            )
+        else:
+            self.handle_stream(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_chunk_response,
+            )
 
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
     server_address = (host, port)
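When a request sets "stream": true, the same endpoints reply with server-sent events: one data: line per JSON chunk, closed by data: [DONE]. A hedged sketch of a streaming client for the new endpoint (URL and framing assume the handler code above; this is not part of the commit):

    import json
    import requests

    with requests.post(
        "http://localhost:8080/v1/completions",  # assumed host/port
        json={"prompt": "Once upon a time", "max_tokens": 50, "stream": True},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line:
                continue                          # blank separators between SSE events
            payload = line.decode().removeprefix("data: ")
            if payload == "[DONE]":               # end-of-stream sentinel from the server
                break
            chunk = json.loads(payload)
            print(chunk["choices"][0]["text"], end="", flush=True)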