mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00

chore: add /v1/completions for server (#489)

This commit is contained in:
parent e5dfef5d9a
commit 19a21bfce4
@@ -4,7 +4,7 @@ import time
 import uuid
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
     and how many tokens should be trimmed from the end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=0)
+        return StopCondition(stop_met=True, trim_length=1)
 
     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
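Note on the `trim_length` change above: the callers added later in this diff trim with `tokens = tokens[: -stop_condition.trim_length]`, so returning `trim_length=1` for the EOS case means the EOS token itself is dropped before decoding. A minimal sketch of that interaction; the `StopCondition` namedtuple shape and the `trim_on_stop` helper are illustrative, not part of the commit:

import json
from collections import namedtuple

# Assumed shape of the result type used by stopping_criteria in this module.
StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])

def trim_on_stop(tokens, stop_condition):
    # Hypothetical helper mirroring the caller pattern used in this diff:
    # drop the matched stop sequence (or, after this change, the EOS token).
    if stop_condition.stop_met and stop_condition.trim_length:
        return tokens[: -stop_condition.trim_length]
    return tokens

# With trim_length=1 the EOS id (2 in this toy example) no longer leaks
# into the decoded completion text.
print(trim_on_stop([5, 6, 7, 2], StopCondition(stop_met=True, trim_length=1)))  # [5, 6, 7]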
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
-def create_response(chat_id, requested_model, prompt, tokens, text):
+def create_chat_response(chat_id, requested_model, prompt, tokens, text):
     response = {
         "id": chat_id,
         "object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
     return response
 
 
-def create_chunk_response(chat_id, requested_model, next_chunk):
+def create_completion_response(completion_id, requested_model, prompt, tokens, text):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+        "choices": [
+            {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {
+            "prompt_tokens": len(prompt),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(prompt) + len(tokens),
+        },
+    }
+
+
+def create_chat_chunk_response(chat_id, requested_model, next_chunk):
     response = {
         "id": chat_id,
         "object": "chat.completion.chunk",
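The new `create_completion_response` mirrors the OpenAI `text_completion` object: usage counts are simply the lengths of the prompt and generated token-id sequences, and `finish_reason` is always reported as "length". A rough sketch of the resulting payload shape, using toy token ids (everything below is illustrative, not from the commit):

import json
import time
import uuid

# Toy stand-ins for the tokenized prompt and the generated token ids.
prompt_token_ids = [101, 2023, 2003, 102]
generated_token_ids = [7592, 2088, 999]

example = {
    "id": f"cmpl-{uuid.uuid4()}",
    "object": "text_completion",
    "created": int(time.time()),
    "model": "default_model",
    "system_fingerprint": f"fp_{uuid.uuid4()}",
    "choices": [
        {"text": "hello world!", "index": 0, "logprobs": None, "finish_reason": "length"}
    ],
    "usage": {
        "prompt_tokens": len(prompt_token_ids),  # 4
        "completion_tokens": len(generated_token_ids),  # 3
        "total_tokens": len(prompt_token_ids) + len(generated_token_ids),  # 7
    },
}
print(json.dumps(example, indent=2))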
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
     return response
 
 
+def create_completion_chunk_response(completion_id, requested_model, next_chunk):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "choices": [
+            {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
+        ],
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+    }
+
+
 class APIHandler(BaseHTTPRequestHandler):
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
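`create_completion_chunk_response` is the streaming counterpart: no `usage` block and a `finish_reason` of `None`, since chunks are emitted mid-generation. Each chunk is framed as a server-sent-events `data:` line by the streaming handler further down; a hedged sketch of one such frame (ids and text are made up):

import json
import time
import uuid

# Illustrative chunk in the shape produced by create_completion_chunk_response.
chunk = {
    "id": f"cmpl-{uuid.uuid4()}",
    "object": "text_completion",
    "created": int(time.time()),
    "choices": [
        {"text": " world", "index": 0, "logprobs": None, "finish_reason": None}
    ],
    "model": "default_model",
    "system_fingerprint": f"fp_{uuid.uuid4()}",
}

# The streaming handler writes each chunk as one SSE event followed by a blank line.
sse_event = f"data: {json.dumps(chunk)}\n\n"
print(sse_event, end="")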
@@ -186,14 +217,128 @@ class APIHandler(BaseHTTPRequestHandler):
             post_data = self.rfile.read(content_length)
             self._set_headers(200)
 
-            response = self.handle_post_request(post_data)
+            response = self.handle_chat_completions(post_data)
 
+            self.wfile.write(json.dumps(response).encode("utf-8"))
+        elif self.path == "/v1/completions":
+            content_length = int(self.headers["Content-Length"])
+            post_data = self.rfile.read(content_length)
+            self._set_headers(200)
+
+            response = self.handle_completions(post_data)
+
             self.wfile.write(json.dumps(response).encode("utf-8"))
         else:
             self._set_headers(404)
             self.wfile.write(b"Not Found")
 
-    def handle_post_request(self, post_data):
+    def generate_response(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, mx.array, List[int], str], dict],
+    ):
+        tokens = []
+        for token, _ in zip(
+            generate(
+                prompt,
+                _model,
+                temperature,
+                top_p=top_p,
+            ),
+            range(max_tokens),
+        ):
+            tokens.append(token)
+            stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
+            if stop_condition.stop_met:
+                if stop_condition.trim_length:
+                    tokens = tokens[: -stop_condition.trim_length]
+                break
+
+        text = _tokenizer.decode(tokens)
+        return response_creator(response_id, requested_model, prompt, tokens, text)
+
+    def hanlde_stream(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, str], dict],
+    ):
+        self.send_response(200)
+        self.send_header("Content-type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.end_headers()
+        max_stop_id_sequence_len = (
+            max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+        )
+        tokens = []
+        current_generated_text_index = 0
+        # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+        stop_sequence_buffer = []
+        REPLACEMENT_CHAR = "\ufffd"
+        for token, _ in zip(
+            generate(
+                prompt,
+                _model,
+                temperature,
+                top_p=top_p,
+            ),
+            range(max_tokens),
+        ):
+            tokens.append(token)
+            stop_sequence_buffer.append(token)
+            if len(stop_sequence_buffer) > max_stop_id_sequence_len:
+                if REPLACEMENT_CHAR in _tokenizer.decode(token):
+                    continue
+                stop_condition = stopping_criteria(
+                    tokens,
+                    stop_id_sequences,
+                    eos_token_id,
+                )
+                if stop_condition.stop_met:
+                    if stop_condition.trim_length:
+                        tokens = tokens[: -stop_condition.trim_length]
+                    break
+                # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
+                generated_text = _tokenizer.decode(tokens)
+                next_chunk = generated_text[current_generated_text_index:]
+                current_generated_text_index = len(generated_text)
+
+                response = response_creator(response_id, requested_model, next_chunk)
+                try:
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+                    stop_sequence_buffer = []
+                except Exception as e:
+                    print(e)
+                    break
+        # check is there any remaining text to send
+        if stop_sequence_buffer:
+            generated_text = _tokenizer.decode(tokens)
+            next_chunk = generated_text[current_generated_text_index:]
+            response = response_creator(response_id, requested_model, next_chunk)
+            try:
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            except Exception as e:
+                print(e)
+
+        self.wfile.write(f"data: [DONE]\n\n".encode())
+        self.wfile.flush()
+
+    def handle_chat_completions(self, post_data):
         body = json.loads(post_data.decode("utf-8"))
         chat_id = f"chatcmpl-{uuid.uuid4()}"
         if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
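Because `hanlde_stream` (the spelling is as committed) writes each chunk as a `data: <json>` event and ends every stream with `data: [DONE]`, a client can read either endpoint's stream line by line. A minimal stdlib-only consumer sketch; the host, port, and prompt below are assumptions, not part of this diff:

import json
import urllib.request

# Assumed local server address; adjust host/port to however the server was started.
url = "http://localhost:8080/v1/completions"
payload = {"prompt": "Write a haiku about autumn", "max_tokens": 64, "stream": True}

req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)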
@@ -223,91 +368,74 @@ class APIHandler(BaseHTTPRequestHandler):
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
         if not stream:
-            tokens = []
-            for token, _ in zip(
-                generate(
-                    prompt,
-                    _model,
-                    temperature,
-                    top_p=top_p,
-                ),
-                range(max_tokens),
-            ):
-                tokens.append(token)
-                stop_condition = stopping_criteria(
-                    tokens, stop_id_sequences, eos_token_id
-                )
-                if stop_condition.stop_met:
-                    if stop_condition.trim_length:
-                        tokens = tokens[: -stop_condition.trim_length]
-                    break
-
-            text = _tokenizer.decode(tokens)
-            return create_response(chat_id, requested_model, prompt, tokens, text)
-        else:
-            self.send_response(200)
-            self.send_header("Content-type", "text/event-stream")
-            self.send_header("Cache-Control", "no-cache")
-            self.end_headers()
-            max_stop_id_sequence_len = (
-                max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
-            )
-            tokens = []
-            current_generated_text_index = 0
-            # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
-            stop_sequence_buffer = []
-            REPLACEMENT_CHAR = "\ufffd"
-            for token, _ in zip(
-                generate(
-                    prompt,
-                    _model,
-                    temperature,
-                    top_p=top_p,
-                ),
-                range(max_tokens),
-            ):
-                tokens.append(token)
-                stop_sequence_buffer.append(token)
-                if len(stop_sequence_buffer) > max_stop_id_sequence_len:
-                    if REPLACEMENT_CHAR in _tokenizer.decode(token):
-                        continue
-                    stop_condition = stopping_criteria(
-                        tokens,
-                        stop_id_sequences,
-                        eos_token_id,
-                    )
-                    if stop_condition.stop_met:
-                        if stop_condition.trim_length:
-                            tokens = tokens[: -stop_condition.trim_length]
-                        break
-                    # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
-                    generated_text = _tokenizer.decode(tokens)
-                    next_chunk = generated_text[current_generated_text_index:]
-                    current_generated_text_index = len(generated_text)
-
-                    response = create_chunk_response(
-                        chat_id, requested_model, next_chunk
-                    )
-                    try:
-                        self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                        self.wfile.flush()
-                        stop_sequence_buffer = []
-                    except Exception as e:
-                        print(e)
-                        break
-            # check is there any remaining text to send
-            if stop_sequence_buffer:
-                generated_text = _tokenizer.decode(tokens)
-                next_chunk = generated_text[current_generated_text_index:]
-                response = create_chunk_response(chat_id, requested_model, next_chunk)
-                try:
-                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                    self.wfile.flush()
-                except Exception as e:
-                    print(e)
-
-            self.wfile.write(f"data: [DONE]\n\n".encode())
-            self.wfile.flush()
+            return self.generate_response(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_response,
+            )
+        else:
+            self.hanlde_stream(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_chunk_response,
+            )
+
+    def handle_completions(self, post_data):
+        body = json.loads(post_data.decode("utf-8"))
+        completion_id = f"cmpl-{uuid.uuid4()}"
+        prompt_text = body["prompt"]
+        prompt = _tokenizer.encode(prompt_text, return_tensors="np")
+        prompt = mx.array(prompt[0])
+        stop_words = body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+                0
+            ]
+            for stop_word in stop_words
+        ]
+        eos_token_id = _tokenizer.eos_token_id
+        max_tokens = body.get("max_tokens", 100)
+        stream = body.get("stream", False)
+        requested_model = body.get("model", "default_model")
+        temperature = body.get("temperature", 1.0)
+        top_p = body.get("top_p", 1.0)
+        if not stream:
+            return self.generate_response(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_response,
+            )
+        else:
+            self.hanlde_stream(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_chunk_response,
+            )
 
 
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
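For the non-streaming path, `handle_completions` accepts `prompt` plus optional `stop`, `max_tokens`, `stream`, `model`, `temperature`, and `top_p`, and returns a single `text_completion` object. A hedged request sketch against an assumed local instance (host, port, prompt, and parameter values are illustrative):

import json
import urllib.request

payload = {
    "prompt": "Q: What is the capital of France?\nA:",
    "max_tokens": 32,     # server default is 100 if omitted
    "temperature": 0.7,   # server default is 1.0
    "stop": ["\n"],       # string or list; encoded into stop_id_sequences
}

req = urllib.request.Request(
    "http://localhost:8080/v1/completions",  # assumed address of a running server
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    completion = json.loads(resp.read().decode("utf-8"))

print(completion["choices"][0]["text"])
print(completion["usage"])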