chore: add /v1/completions for server (#489)

Anchen authored 2024-02-27 15:59:33 +11:00, committed by GitHub
parent e5dfef5d9a
commit 19a21bfce4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
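
The diff below wires a new /v1/completions route into the server's do_POST, backed by handle_completions, create_completion_response, and create_completion_chunk_response. As a rough sketch of how a client might exercise the non-streaming path once the server is running (the host, port, model name, and prompt here are illustrative assumptions, not values taken from this commit):

import json
import urllib.request

# Illustrative address: adjust to wherever the server is actually listening.
url = "http://localhost:8080/v1/completions"

# Fields match what handle_completions() reads from the request body.
payload = {
    "model": "default_model",   # optional; the handler falls back to "default_model"
    "prompt": "Write a haiku about the ocean.",
    "max_tokens": 100,
    "temperature": 0.7,
    "top_p": 1.0,
    "stop": ["\n\n"],
}

req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read().decode("utf-8"))

# The response mirrors create_completion_response(): the generated text is in
# choices[0]["text"] and token counts are under "usage".
print(body["choices"][0]["text"])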


@@ -4,7 +4,7 @@ import time
 import uuid
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional

 import mlx.core as mx
 import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
     and how many tokens should be trimmed from the end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=0)
+        return StopCondition(stop_met=True, trim_length=1)

     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     return prompt.rstrip()


-def create_response(chat_id, requested_model, prompt, tokens, text):
+def create_chat_response(chat_id, requested_model, prompt, tokens, text):
     response = {
         "id": chat_id,
         "object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
     return response


-def create_chunk_response(chat_id, requested_model, next_chunk):
+def create_completion_response(completion_id, requested_model, prompt, tokens, text):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+        "choices": [
+            {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {
+            "prompt_tokens": len(prompt),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(prompt) + len(tokens),
+        },
+    }
+
+
+def create_chat_chunk_response(chat_id, requested_model, next_chunk):
     response = {
         "id": chat_id,
         "object": "chat.completion.chunk",
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
     return response


+def create_completion_chunk_response(completion_id, requested_model, next_chunk):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "choices": [
+            {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
+        ],
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+    }
+
+
 class APIHandler(BaseHTTPRequestHandler):
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
@@ -186,14 +217,128 @@ class APIHandler(BaseHTTPRequestHandler):
             post_data = self.rfile.read(content_length)
             self._set_headers(200)
-            response = self.handle_post_request(post_data)
+            response = self.handle_chat_completions(post_data)
+            self.wfile.write(json.dumps(response).encode("utf-8"))
+        elif self.path == "/v1/completions":
+            content_length = int(self.headers["Content-Length"])
+            post_data = self.rfile.read(content_length)
+            self._set_headers(200)
+            response = self.handle_completions(post_data)
             self.wfile.write(json.dumps(response).encode("utf-8"))
         else:
             self._set_headers(404)
             self.wfile.write(b"Not Found")

-    def handle_post_request(self, post_data):
+    def generate_response(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, mx.array, List[int], str], dict],
+    ):
+        tokens = []
+        for token, _ in zip(
+            generate(
+                prompt,
+                _model,
+                temperature,
+                top_p=top_p,
+            ),
+            range(max_tokens),
+        ):
+            tokens.append(token)
+            stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
+            if stop_condition.stop_met:
+                if stop_condition.trim_length:
+                    tokens = tokens[: -stop_condition.trim_length]
+                break
+
+        text = _tokenizer.decode(tokens)
+        return response_creator(response_id, requested_model, prompt, tokens, text)
+
+    def hanlde_stream(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, str], dict],
+    ):
+        self.send_response(200)
+        self.send_header("Content-type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.end_headers()
+        max_stop_id_sequence_len = (
+            max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+        )
+        tokens = []
+        current_generated_text_index = 0
+        # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+        stop_sequence_buffer = []
+        REPLACEMENT_CHAR = "\ufffd"
+        for token, _ in zip(
+            generate(
+                prompt,
+                _model,
+                temperature,
+                top_p=top_p,
+            ),
+            range(max_tokens),
+        ):
+            tokens.append(token)
+            stop_sequence_buffer.append(token)
+            if len(stop_sequence_buffer) > max_stop_id_sequence_len:
+                if REPLACEMENT_CHAR in _tokenizer.decode(token):
+                    continue
+                stop_condition = stopping_criteria(
+                    tokens,
+                    stop_id_sequences,
+                    eos_token_id,
+                )
+                if stop_condition.stop_met:
+                    if stop_condition.trim_length:
+                        tokens = tokens[: -stop_condition.trim_length]
+                    break
+                # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
+                generated_text = _tokenizer.decode(tokens)
+                next_chunk = generated_text[current_generated_text_index:]
+                current_generated_text_index = len(generated_text)
+
+                response = response_creator(response_id, requested_model, next_chunk)
+                try:
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+                    stop_sequence_buffer = []
+                except Exception as e:
+                    print(e)
+                    break
+        # check is there any remaining text to send
+        if stop_sequence_buffer:
+            generated_text = _tokenizer.decode(tokens)
+            next_chunk = generated_text[current_generated_text_index:]
+            response = response_creator(response_id, requested_model, next_chunk)
+            try:
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            except Exception as e:
+                print(e)
+
+        self.wfile.write(f"data: [DONE]\n\n".encode())
+        self.wfile.flush()
+
+    def handle_chat_completions(self, post_data):
         body = json.loads(post_data.decode("utf-8"))
         chat_id = f"chatcmpl-{uuid.uuid4()}"
         if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -223,91 +368,74 @@ class APIHandler(BaseHTTPRequestHandler):
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)

         if not stream:
-            tokens = []
-            for token, _ in zip(
-                generate(
-                    prompt,
-                    _model,
-                    temperature,
-                    top_p=top_p,
-                ),
-                range(max_tokens),
-            ):
-                tokens.append(token)
-                stop_condition = stopping_criteria(
-                    tokens, stop_id_sequences, eos_token_id
-                )
-                if stop_condition.stop_met:
-                    if stop_condition.trim_length:
-                        tokens = tokens[: -stop_condition.trim_length]
-                    break
-
-            text = _tokenizer.decode(tokens)
-            return create_response(chat_id, requested_model, prompt, tokens, text)
-        else:
-            self.send_response(200)
-            self.send_header("Content-type", "text/event-stream")
-            self.send_header("Cache-Control", "no-cache")
-            self.end_headers()
-            max_stop_id_sequence_len = (
-                max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
-            )
-            tokens = []
-            current_generated_text_index = 0
-            # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
-            stop_sequence_buffer = []
-            REPLACEMENT_CHAR = "\ufffd"
-            for token, _ in zip(
-                generate(
-                    prompt,
-                    _model,
-                    temperature,
-                    top_p=top_p,
-                ),
-                range(max_tokens),
-            ):
-                tokens.append(token)
-                stop_sequence_buffer.append(token)
-                if len(stop_sequence_buffer) > max_stop_id_sequence_len:
-                    if REPLACEMENT_CHAR in _tokenizer.decode(token):
-                        continue
-                    stop_condition = stopping_criteria(
-                        tokens,
-                        stop_id_sequences,
-                        eos_token_id,
-                    )
-                    if stop_condition.stop_met:
-                        if stop_condition.trim_length:
-                            tokens = tokens[: -stop_condition.trim_length]
-                        break
-                    # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
-                    generated_text = _tokenizer.decode(tokens)
-                    next_chunk = generated_text[current_generated_text_index:]
-                    current_generated_text_index = len(generated_text)
-
-                    response = create_chunk_response(
-                        chat_id, requested_model, next_chunk
-                    )
-                    try:
-                        self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                        self.wfile.flush()
-                        stop_sequence_buffer = []
-                    except Exception as e:
-                        print(e)
-                        break
-            # check is there any remaining text to send
-            if stop_sequence_buffer:
-                generated_text = _tokenizer.decode(tokens)
-                next_chunk = generated_text[current_generated_text_index:]
-                response = create_chunk_response(chat_id, requested_model, next_chunk)
-                try:
-                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                    self.wfile.flush()
-                except Exception as e:
-                    print(e)
-
-            self.wfile.write(f"data: [DONE]\n\n".encode())
-            self.wfile.flush()
+            return self.generate_response(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_response,
+            )
+        else:
+            self.hanlde_stream(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_chunk_response,
+            )
+
+    def handle_completions(self, post_data):
+        body = json.loads(post_data.decode("utf-8"))
+        completion_id = f"cmpl-{uuid.uuid4()}"
+        prompt_text = body["prompt"]
+        prompt = _tokenizer.encode(prompt_text, return_tensors="np")
+        prompt = mx.array(prompt[0])
+        stop_words = body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+                0
+            ]
+            for stop_word in stop_words
+        ]
+        eos_token_id = _tokenizer.eos_token_id
+        max_tokens = body.get("max_tokens", 100)
+        stream = body.get("stream", False)
+        requested_model = body.get("model", "default_model")
+        temperature = body.get("temperature", 1.0)
+        top_p = body.get("top_p", 1.0)
+        if not stream:
+            return self.generate_response(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_response,
+            )
+        else:
+            self.hanlde_stream(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_chunk_response,
+            )


 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
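
For the streaming variant, the handler (hanlde_stream) writes server-sent events and finishes with a data: [DONE] sentinel. A minimal, stdlib-only sketch of consuming that stream follows; the localhost:8080 address and the prompt are again illustrative assumptions rather than anything specified by this commit.

import json
import urllib.request

# Illustrative address; the actual host/port depend on how the server was started.
url = "http://localhost:8080/v1/completions"

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 50,
    "stream": True,  # routes the request through the streaming (SSE) code path
}

req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:                      # the response body is a stream of SSE lines
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":                   # sentinel written after the last chunk
            break
        chunk = json.loads(data)
        # Each event carries a create_completion_chunk_response() payload,
        # so the incremental text is read from choices[0]["text"].
        print(chunk["choices"][0]["text"], end="", flush=True)
print()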