mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
chore: add /v1/completions for server (#489)
This commit is contained in:
parent
e5dfef5d9a
commit
19a21bfce4
@ -4,7 +4,7 @@ import time
|
||||
import uuid
|
||||
from collections import namedtuple
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -46,7 +46,7 @@ def stopping_criteria(
|
||||
and how many tokens should be trimmed from the end if it has (`trim_length`).
|
||||
"""
|
||||
if tokens and tokens[-1] == eos_token_id:
|
||||
return StopCondition(stop_met=True, trim_length=0)
|
||||
return StopCondition(stop_met=True, trim_length=1)
|
||||
|
||||
for stop_ids in stop_id_sequences:
|
||||
if len(tokens) >= len(stop_ids):
|
||||
@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
def create_response(chat_id, requested_model, prompt, tokens, text):
|
||||
def create_chat_response(chat_id, requested_model, prompt, tokens, text):
|
||||
response = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion",
|
||||
@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
|
||||
return response
|
||||
|
||||
|
||||
def create_chunk_response(chat_id, requested_model, next_chunk):
|
||||
def create_completion_response(completion_id, requested_model, prompt, tokens, text):
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": requested_model,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
"choices": [
|
||||
{"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": len(prompt),
|
||||
"completion_tokens": len(tokens),
|
||||
"total_tokens": len(prompt) + len(tokens),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_chat_chunk_response(chat_id, requested_model, next_chunk):
|
||||
response = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion.chunk",
|
||||
@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
|
||||
return response
|
||||
|
||||
|
||||
def create_completion_chunk_response(completion_id, requested_model, next_chunk):
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"choices": [
|
||||
{"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
|
||||
],
|
||||
"model": requested_model,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
}
|
||||
|
||||
|
||||
class APIHandler(BaseHTTPRequestHandler):
|
||||
def _set_headers(self, status_code=200):
|
||||
self.send_response(status_code)
|
||||
@ -186,43 +217,33 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
post_data = self.rfile.read(content_length)
|
||||
self._set_headers(200)
|
||||
|
||||
response = self.handle_post_request(post_data)
|
||||
response = self.handle_chat_completions(post_data)
|
||||
|
||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||
elif self.path == "/v1/completions":
|
||||
content_length = int(self.headers["Content-Length"])
|
||||
post_data = self.rfile.read(content_length)
|
||||
self._set_headers(200)
|
||||
|
||||
response = self.handle_completions(post_data)
|
||||
|
||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||
else:
|
||||
self._set_headers(404)
|
||||
self.wfile.write(b"Not Found")
|
||||
|
||||
def handle_post_request(self, post_data):
|
||||
body = json.loads(post_data.decode("utf-8"))
|
||||
chat_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
|
||||
prompt = _tokenizer.apply_chat_template(
|
||||
body["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
else:
|
||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||
prompt = _tokenizer.encode(prompt, return_tensors="np")
|
||||
|
||||
prompt = mx.array(prompt[0])
|
||||
stop_words = body.get("stop", [])
|
||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||
stop_id_sequences = [
|
||||
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
|
||||
0
|
||||
]
|
||||
for stop_word in stop_words
|
||||
]
|
||||
eos_token_id = _tokenizer.eos_token_id
|
||||
max_tokens = body.get("max_tokens", 100)
|
||||
stream = body.get("stream", False)
|
||||
requested_model = body.get("model", "default_model")
|
||||
temperature = body.get("temperature", 1.0)
|
||||
top_p = body.get("top_p", 1.0)
|
||||
if not stream:
|
||||
def generate_response(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
response_id: str,
|
||||
requested_model: str,
|
||||
stop_id_sequences: List[np.ndarray],
|
||||
eos_token_id: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
response_creator: Callable[[str, str, mx.array, List[int], str], dict],
|
||||
):
|
||||
tokens = []
|
||||
for token, _ in zip(
|
||||
generate(
|
||||
@ -234,17 +255,27 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
range(max_tokens),
|
||||
):
|
||||
tokens.append(token)
|
||||
stop_condition = stopping_criteria(
|
||||
tokens, stop_id_sequences, eos_token_id
|
||||
)
|
||||
stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
|
||||
if stop_condition.stop_met:
|
||||
if stop_condition.trim_length:
|
||||
tokens = tokens[: -stop_condition.trim_length]
|
||||
break
|
||||
|
||||
text = _tokenizer.decode(tokens)
|
||||
return create_response(chat_id, requested_model, prompt, tokens, text)
|
||||
else:
|
||||
return response_creator(response_id, requested_model, prompt, tokens, text)
|
||||
|
||||
def hanlde_stream(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
response_id: str,
|
||||
requested_model: str,
|
||||
stop_id_sequences: List[np.ndarray],
|
||||
eos_token_id: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
response_creator: Callable[[str, str, str], dict],
|
||||
):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
@ -285,9 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
next_chunk = generated_text[current_generated_text_index:]
|
||||
current_generated_text_index = len(generated_text)
|
||||
|
||||
response = create_chunk_response(
|
||||
chat_id, requested_model, next_chunk
|
||||
)
|
||||
response = response_creator(response_id, requested_model, next_chunk)
|
||||
try:
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
@ -299,7 +328,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
if stop_sequence_buffer:
|
||||
generated_text = _tokenizer.decode(tokens)
|
||||
next_chunk = generated_text[current_generated_text_index:]
|
||||
response = create_chunk_response(chat_id, requested_model, next_chunk)
|
||||
response = response_creator(response_id, requested_model, next_chunk)
|
||||
try:
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
@ -309,6 +338,105 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.wfile.write(f"data: [DONE]\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
def handle_chat_completions(self, post_data):
|
||||
body = json.loads(post_data.decode("utf-8"))
|
||||
chat_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
|
||||
prompt = _tokenizer.apply_chat_template(
|
||||
body["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
else:
|
||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||
prompt = _tokenizer.encode(prompt, return_tensors="np")
|
||||
|
||||
prompt = mx.array(prompt[0])
|
||||
stop_words = body.get("stop", [])
|
||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||
stop_id_sequences = [
|
||||
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
|
||||
0
|
||||
]
|
||||
for stop_word in stop_words
|
||||
]
|
||||
eos_token_id = _tokenizer.eos_token_id
|
||||
max_tokens = body.get("max_tokens", 100)
|
||||
stream = body.get("stream", False)
|
||||
requested_model = body.get("model", "default_model")
|
||||
temperature = body.get("temperature", 1.0)
|
||||
top_p = body.get("top_p", 1.0)
|
||||
if not stream:
|
||||
return self.generate_response(
|
||||
prompt,
|
||||
chat_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_chat_response,
|
||||
)
|
||||
else:
|
||||
self.hanlde_stream(
|
||||
prompt,
|
||||
chat_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_chat_chunk_response,
|
||||
)
|
||||
|
||||
def handle_completions(self, post_data):
|
||||
body = json.loads(post_data.decode("utf-8"))
|
||||
completion_id = f"cmpl-{uuid.uuid4()}"
|
||||
prompt_text = body["prompt"]
|
||||
prompt = _tokenizer.encode(prompt_text, return_tensors="np")
|
||||
prompt = mx.array(prompt[0])
|
||||
stop_words = body.get("stop", [])
|
||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||
stop_id_sequences = [
|
||||
_tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
|
||||
0
|
||||
]
|
||||
for stop_word in stop_words
|
||||
]
|
||||
eos_token_id = _tokenizer.eos_token_id
|
||||
max_tokens = body.get("max_tokens", 100)
|
||||
stream = body.get("stream", False)
|
||||
requested_model = body.get("model", "default_model")
|
||||
temperature = body.get("temperature", 1.0)
|
||||
top_p = body.get("top_p", 1.0)
|
||||
if not stream:
|
||||
return self.generate_response(
|
||||
prompt,
|
||||
completion_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_completion_response,
|
||||
)
|
||||
else:
|
||||
self.hanlde_stream(
|
||||
prompt,
|
||||
completion_id,
|
||||
requested_model,
|
||||
stop_id_sequences,
|
||||
eos_token_id,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
create_completion_chunk_response,
|
||||
)
|
||||
|
||||
|
||||
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
|
||||
server_address = (host, port)
|
||||
|
Loading…
Reference in New Issue
Block a user