chore: add /v1/completions for server (#489)

Author: Anchen, 2024-02-27 15:59:33 +11:00 (committed by GitHub)
Parent: e5dfef5d9a
Commit: 19a21bfce4


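This commit adds an OpenAI-style /v1/completions endpoint next to the existing /v1/chat/completions route. A minimal client sketch for the new endpoint, assuming the server is listening locally on port 8080 (host and port are illustrative; the payload fields match the ones read in handle_completions below):

# [example sketch - not part of the diff; host/port are assumptions]
import json
from urllib import request

payload = {
    "model": "default_model",
    "prompt": "Write a haiku about the ocean.",
    "max_tokens": 100,
    "temperature": 1.0,
    "top_p": 1.0,
    "stop": ["\n\n"],
    "stream": False,
}
req = request.Request(
    "http://localhost:8080/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])
# [end example]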
@@ -4,7 +4,7 @@ import time
import uuid
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Optional, Tuple
from typing import Callable, List, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
and how many tokens should be trimmed from the end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=0)
return StopCondition(stop_met=True, trim_length=1)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
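The change above makes stopping_criteria report trim_length=1 when the last token is the eos token, so the eos is trimmed from the output before decoding, mirroring how explicit stop sequences are already trimmed. A small sketch of how the callers below consume the result (StopCondition fields are taken from this file; token values are illustrative):

# [example sketch - not part of the diff; token ids are made up]
from collections import namedtuple

StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])

tokens = [512, 88, 2]                          # assume 2 is eos_token_id
condition = StopCondition(stop_met=True, trim_length=1)
if condition.stop_met and condition.trim_length:
    tokens = tokens[: -condition.trim_length]  # eos dropped before decoding
assert tokens == [512, 88]
# [end example]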
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
return prompt.rstrip()
def create_response(chat_id, requested_model, prompt, tokens, text):
def create_chat_response(chat_id, requested_model, prompt, tokens, text):
response = {
"id": chat_id,
"object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
return response
def create_chunk_response(chat_id, requested_model, next_chunk):
def create_completion_response(completion_id, requested_model, prompt, tokens, text):
return {
"id": completion_id,
"object": "text_completion",
"created": int(time.time()),
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"choices": [
{"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
],
"usage": {
"prompt_tokens": len(prompt),
"completion_tokens": len(tokens),
"total_tokens": len(prompt) + len(tokens),
},
}
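For reference, a non-streaming response built by create_completion_response has roughly this shape (the id, fingerprint, timestamp, and token counts are illustrative):

# [example sketch - not part of the diff; all values are made up]
example_completion = {
    "id": "cmpl-3f0c9a2e",
    "object": "text_completion",
    "created": 1709000000,
    "model": "default_model",
    "system_fingerprint": "fp_1234abcd",
    "choices": [
        {"text": "Blue waves fold softly", "index": 0, "logprobs": None, "finish_reason": "length"}
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 5, "total_tokens": 14},
}
# [end example]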
def create_chat_chunk_response(chat_id, requested_model, next_chunk):
response = {
"id": chat_id,
"object": "chat.completion.chunk",
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
return response
def create_completion_chunk_response(completion_id, requested_model, next_chunk):
return {
"id": completion_id,
"object": "text_completion",
"created": int(time.time()),
"choices": [
{"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
],
"model": requested_model,
"system_fingerprint": f"fp_{uuid.uuid4()}",
}
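When stream is true, these chunk payloads are framed as server-sent events; each write in handle_stream below has roughly this form (chunk values are illustrative), followed by a final "data: [DONE]" sentinel:

# [example sketch - not part of the diff; chunk values are made up]
import json

chunk = {
    "id": "cmpl-3f0c9a2e",
    "object": "text_completion",
    "created": 1709000000,
    "choices": [{"text": "Blue ", "index": 0, "logprobs": None, "finish_reason": None}],
    "model": "default_model",
    "system_fingerprint": "fp_1234abcd",
}
event = f"data: {json.dumps(chunk)}\n\n".encode()   # one SSE event per decoded chunk
# [end example]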
class APIHandler(BaseHTTPRequestHandler):
def _set_headers(self, status_code=200):
self.send_response(status_code)
@@ -186,43 +217,33 @@ class APIHandler(BaseHTTPRequestHandler):
post_data = self.rfile.read(content_length)
self._set_headers(200)
response = self.handle_post_request(post_data)
response = self.handle_chat_completions(post_data)
self.wfile.write(json.dumps(response).encode("utf-8"))
elif self.path == "/v1/completions":
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
self._set_headers(200)
response = self.handle_completions(post_data)
self.wfile.write(json.dumps(response).encode("utf-8"))
else:
self._set_headers(404)
self.wfile.write(b"Not Found")
def handle_post_request(self, post_data):
body = json.loads(post_data.decode("utf-8"))
chat_id = f"chatcmpl-{uuid.uuid4()}"
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
prompt = _tokenizer.apply_chat_template(
body["messages"],
tokenize=True,
add_generation_prompt=True,
return_tensors="np",
)
else:
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = _tokenizer.encode(prompt, return_tensors="np")
prompt = mx.array(prompt[0])
stop_words = body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
0
]
for stop_word in stop_words
]
eos_token_id = _tokenizer.eos_token_id
max_tokens = body.get("max_tokens", 100)
stream = body.get("stream", False)
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
if not stream:
def generate_response(
self,
prompt: mx.array,
response_id: str,
requested_model: str,
stop_id_sequences: List[np.ndarray],
eos_token_id: int,
max_tokens: int,
temperature: float,
top_p: float,
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
):
tokens = []
for token, _ in zip(
generate(
@@ -234,17 +255,27 @@ class APIHandler(BaseHTTPRequestHandler):
range(max_tokens),
):
tokens.append(token)
stop_condition = stopping_criteria(
tokens, stop_id_sequences, eos_token_id
)
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
if stop_condition.stop_met:
if stop_condition.trim_length:
tokens = tokens[: -stop_condition.trim_length]
break
text = _tokenizer.decode(tokens)
return create_response(chat_id, requested_model, prompt, tokens, text)
else:
return response_creator(response_id, requested_model, prompt, tokens, text)
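generate_response is now endpoint-agnostic: the caller injects a response_creator callable that turns (response_id, requested_model, prompt, tokens, text) into the endpoint-specific payload (create_chat_response or create_completion_response). A self-contained sketch of that injection pattern, with a toy creator standing in for the real builders:

# [example sketch - not part of the diff; toy_creator stands in for the real builders]
from typing import Callable, List

ResponseCreator = Callable[[str, str, list, List[int], str], dict]

def finish(creator: ResponseCreator) -> dict:
    prompt_tokens = [1, 15, 42]   # stands in for the encoded prompt array
    generated = [7, 9, 11]
    return creator("cmpl-demo", "default_model", prompt_tokens, generated, "hello")

def toy_creator(rid, model, prompt, tokens, text):
    return {"id": rid, "model": model, "text": text,
            "usage": {"prompt_tokens": len(prompt), "completion_tokens": len(tokens)}}

print(finish(toy_creator))
# [end example]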
def handle_stream(
self,
prompt: mx.array,
response_id: str,
requested_model: str,
stop_id_sequences: List[np.ndarray],
eos_token_id: int,
max_tokens: int,
temperature: float,
top_p: float,
response_creator: Callable[[str, str, str], dict],
):
self.send_response(200)
self.send_header("Content-type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
@@ -285,9 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
next_chunk = generated_text[current_generated_text_index:]
current_generated_text_index = len(generated_text)
response = create_chunk_response(
chat_id, requested_model, next_chunk
)
response = response_creator(response_id, requested_model, next_chunk)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
@@ -299,7 +328,7 @@ class APIHandler(BaseHTTPRequestHandler):
if stop_sequence_buffer:
generated_text = _tokenizer.decode(tokens)
next_chunk = generated_text[current_generated_text_index:]
response = create_chunk_response(chat_id, requested_model, next_chunk)
response = response_creator(response_id, requested_model, next_chunk)
try:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
@@ -309,6 +338,105 @@ class APIHandler(BaseHTTPRequestHandler):
self.wfile.write(f"data: [DONE]\n\n".encode())
self.wfile.flush()
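A client consuming the stream reads SSE lines until the [DONE] sentinel; a rough sketch using only the standard library (host, port, and payload are assumptions):

# [example sketch - not part of the diff; host/port are assumptions]
import json
from urllib import request

req = request.Request(
    "http://localhost:8080/v1/completions",
    data=json.dumps({"prompt": "Hello", "stream": True, "max_tokens": 50}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    for raw in resp:                      # the response is a line-oriented event stream
        line = raw.decode("utf-8").strip()
        if not line:
            continue                      # blank separator between events
        if line == "data: [DONE]":
            break
        chunk = json.loads(line[len("data: "):])
        print(chunk["choices"][0]["text"], end="", flush=True)
# [end example]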
def handle_chat_completions(self, post_data):
body = json.loads(post_data.decode("utf-8"))
chat_id = f"chatcmpl-{uuid.uuid4()}"
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
prompt = _tokenizer.apply_chat_template(
body["messages"],
tokenize=True,
add_generation_prompt=True,
return_tensors="np",
)
else:
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = _tokenizer.encode(prompt, return_tensors="np")
prompt = mx.array(prompt[0])
stop_words = body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
0
]
for stop_word in stop_words
]
eos_token_id = _tokenizer.eos_token_id
max_tokens = body.get("max_tokens", 100)
stream = body.get("stream", False)
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
if not stream:
return self.generate_response(
prompt,
chat_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_chat_response,
)
else:
self.handle_stream(
prompt,
chat_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_chat_chunk_response,
)
def handle_completions(self, post_data):
body = json.loads(post_data.decode("utf-8"))
completion_id = f"cmpl-{uuid.uuid4()}"
prompt_text = body["prompt"]
prompt = _tokenizer.encode(prompt_text, return_tensors="np")
prompt = mx.array(prompt[0])
stop_words = body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
0
]
for stop_word in stop_words
]
eos_token_id = _tokenizer.eos_token_id
max_tokens = body.get("max_tokens", 100)
stream = body.get("stream", False)
requested_model = body.get("model", "default_model")
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
if not stream:
return self.generate_response(
prompt,
completion_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_completion_response,
)
else:
self.handle_stream(
prompt,
completion_id,
requested_model,
stop_id_sequences,
eos_token_id,
max_tokens,
temperature,
top_p,
create_completion_chunk_response,
)
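For contrast, the existing chat endpoint takes a messages list rather than a raw prompt; a minimal request body sketch (field values are illustrative, and the remaining fields are read identically by both handlers):

# [example sketch - not part of the diff; values are made up]
chat_body = {
    "model": "default_model",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about the ocean."},
    ],
    "max_tokens": 100,
    "temperature": 1.0,
    "top_p": 1.0,
    "stream": False,
}
# POST this to /v1/chat/completions; handle_chat_completions reads the same
# stop/max_tokens/temperature/top_p/stream fields as handle_completions above.
# [end example]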
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
server_address = (host, port)