chore: add /v1/completions for server (#489)

Anchen 2024-02-27 15:59:33 +11:00 committed by GitHub
parent e5dfef5d9a
commit 19a21bfce4


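For reference, a minimal client-side sketch of the new endpoint (host, port, and model name below are assumptions, not part of this commit; adjust them to however the server was started):

import requests

# Hypothetical address; the real one depends on the arguments passed to run().
url = "http://localhost:8080/v1/completions"
payload = {
    "model": "default_model",
    "prompt": "Write a haiku about the ocean.",
    "max_tokens": 50,
    "temperature": 0.7,
    "stop": ["\n\n"],
}
resp = requests.post(url, json=payload)
print(resp.json()["choices"][0]["text"])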
@@ -4,7 +4,7 @@ import time
import uuid
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Optional, Tuple
from typing import Callable, List, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
and how many tokens should be trimmed from the end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=0)
return StopCondition(stop_met=True, trim_length=1)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
return prompt.rstrip()
def create_response(chat_id, requested_model, prompt, tokens, text):
def create_chat_response(chat_id, requested_model, prompt, tokens, text):
response = {
"id": chat_id,
"object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
return response
def create_chunk_response(chat_id, requested_model, next_chunk):
def create_completion_response(completion_id, requested_model, prompt, tokens, text):
return {
"id": completion_id,
"object": "text_completion",
"created": int(time.time()),
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"choices": [
{"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
],
"usage": {
"prompt_tokens": len(prompt),
"completion_tokens": len(tokens),
"total_tokens": len(prompt) + len(tokens),
},
}
def create_chat_chunk_response(chat_id, requested_model, next_chunk):
response = {
"id": chat_id,
"object": "chat.completion.chunk",
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
return response
def create_completion_chunk_response(completion_id, requested_model, next_chunk):
return {
"id": completion_id,
"object": "text_completion",
"created": int(time.time()),
"choices": [
{"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
],
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
}
class APIHandler(BaseHTTPRequestHandler):
def _set_headers(self, status_code=200):
self.send_response(status_code)
@@ -186,14 +217,128 @@ class APIHandler(BaseHTTPRequestHandler):
post_data = self.rfile.read(content_length)
self._set_headers(200)
response = self.handle_post_request(post_data)
response = self.handle_chat_completions(post_data)
self.wfile.write(json.dumps(response).encode("utf-8"))
elif self.path == "/v1/completions":
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
self._set_headers(200)
response = self.handle_completions(post_data)
self.wfile.write(json.dumps(response).encode("utf-8"))
else:
self._set_headers(404)
self.wfile.write(b"Not Found")
def handle_post_request(self, post_data):
def generate_response(
self,
prompt: mx.array,
response_id: str,
requested_model: str,
stop_id_sequences: List[np.ndarray],
eos_token_id: int,
max_tokens: int,
temperature: float,
top_p: float,
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
):
tokens = []
for token, _ in zip(
generate(
prompt,
_model,
temperature,
top_p=top_p,
),
range(max_tokens),
):
tokens.append(token)
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
if stop_condition.stop_met:
if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length]
break
text = _tokenizer.decode(tokens)
return response_creator(response_id, requested_model, prompt, tokens, text)
def handle_stream(
self,
prompt: mx.array,
response_id: str,
requested_model: str,
stop_id_sequences: List[np.ndarray],
eos_token_id: int,
max_tokens: int,
temperature: float,
top_p: float,
response_creator: Callable[[str, str, str], dict],
):
self.send_response(200)
self.send_header("Content-type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
max_stop_id_sequence_len = (
max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
)
tokens = []
current_generated_text_index = 0
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
stop_sequence_buffer = []
REPLACEMENT_CHAR = "\ufffd"
for token, _ in zip(
generate(
prompt,
_model,
temperature,
top_p=top_p,
),
range(max_tokens),
):
tokens.append(token)
stop_sequence_buffer.append(token)
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
if REPLACEMENT_CHAR in _tokenizer.decode(token):
continue
stop_condition = stopping_criteria(
tokens,
stop_id_sequences,
eos_token_id,
)
if stop_condition.stop_met:
if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length]
break
# This is a workaround because the llama tokenizer emits spaces when decoding token by token.
generated_text = _tokenizer.decode(tokens)
next_chunk = generated_text[current_generated_text_index:]
current_generated_text_index = len(generated_text)
response = response_creator(response_id, requested_model, next_chunk)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
stop_sequence_buffer = []
except Exception as e:
print(e)
break
# Check whether any remaining text still needs to be sent
if stop_sequence_buffer:
generated_text = _tokenizer.decode(tokens)
next_chunk = generated_text[current_generated_text_index:]
response = response_creator(response_id, requested_model, next_chunk)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
except Exception as e:
print(e)
self.wfile.write(f"data: [DONE]\n\n".encode())
self.wfile.flush()
def handle_chat_completions(self, post_data):
body = json.loads(post_data.decode("utf-8"))
chat_id = f"chatcmpl-{uuid.uuid4()}"
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -223,91 +368,74 @@ class APIHandler(BaseHTTPRequestHandler):
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
if not stream:
tokens = []
for token, _ in zip(
generate(
prompt,
_model,
temperature,
top_p=top_p,
),
range(max_tokens),
):
tokens.append(token)
stop_condition = stopping_criteria(
tokens, stop_id_sequences, eos_token_id
)
if stop_condition.stop_met:
if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length]
break
text = _tokenizer.decode(tokens)
return create_response(chat_id, requested_model, prompt, tokens, text)
else:
self.send_response(200)
self.send_header("Content-type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
max_stop_id_sequence_len = (
max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
return self.generate_response(
prompt,
chat_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_chat_response,
)
else:
self.handle_stream(
prompt,
chat_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_chat_chunk_response,
)
tokens = []
current_generated_text_index = 0
# Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
stop_sequence_buffer = []
REPLACEMENT_CHAR = "\ufffd"
for token, _ in zip(
generate(
prompt,
_model,
temperature,
top_p=top_p,
),
range(max_tokens),
):
tokens.append(token)
stop_sequence_buffer.append(token)
if len(stop_sequence_buffer) > max_stop_id_sequence_len:
if REPLACEMENT_CHAR in _tokenizer.decode(token):
continue
stop_condition = stopping_criteria(
tokens,
stop_id_sequences,
eos_token_id,
)
if stop_condition.stop_met:
if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length]
break
# This is a workaround because the llama tokenizer emits spaces when decoding token by token.
generated_text = _tokenizer.decode(tokens)
next_chunk = generated_text[current_generated_text_index:]
current_generated_text_index = len(generated_text)
response = create_chunk_response(
chat_id, requested_model, next_chunk
)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
stop_sequence_buffer = []
except Exception as e:
print(e)
break
# check is there any remaining text to send
if stop_sequence_buffer:
generated_text = _tokenizer.decode(tokens)
next_chunk = generated_text[current_generated_text_index:]
response = create_chunk_response(chat_id, requested_model, next_chunk)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
except Exception as e:
print(e)
self.wfile.write(f"data: [DONE]\n\n".encode())
self.wfile.flush()
def handle_completions(self, post_data):
body = json.loads(post_data.decode("utf-8"))
completion_id = f"cmpl-{uuid.uuid4()}"
prompt_text = body["prompt"]
prompt = _tokenizer.encode(prompt_text, return_tensors="np")
prompt = mx.array(prompt[0])
stop_words = body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
0
]
for stop_word in stop_words
]
eos_token_id = _tokenizer.eos_token_id
max_tokens = body.get("max_tokens", 100)
stream = body.get("stream", False)
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
if not stream:
return self.generate_response(
prompt,
completion_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_completion_response,
)
else:
self.handle_stream(
prompt,
completion_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_completion_chunk_response,
)
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
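
For the streaming path ("stream": true), the handler writes Server-Sent-Events style lines of the form "data: {json}" and terminates with "data: [DONE]". A rough consumer sketch, with the same assumed host and port as above:

import json
import requests

with requests.post(
    "http://localhost:8080/v1/completions",  # assumed address, see note above
    json={"prompt": "Once upon a time", "max_tokens": 50, "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8").removeprefix("data: ")
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each chunk mirrors create_completion_chunk_response: choices[0].text holds the delta.
        print(chunk["choices"][0]["text"], end="", flush=True)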