From 19a21bfce4f0905404fabfc83b695b98b5bde09f Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 27 Feb 2024 15:59:33 +1100
Subject: [PATCH] chore: add /v1/completions for server (#489)
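
This adds an OpenAI-style /v1/completions endpoint next to the existing
/v1/chat/completions endpoint. The shared generation and streaming logic
moves into the new generate_response and handle_stream helpers, and the
response builders are split into chat and text-completion variants.

A rough usage sketch, assuming the server is running locally on
127.0.0.1:8080 (adjust to however it was started) and that the
third-party `requests` package is available; the prompt and parameters
below are only illustrative:

    import requests

    # Non-streaming text completion against the new endpoint.
    resp = requests.post(
        "http://127.0.0.1:8080/v1/completions",
        json={
            "prompt": "Write a haiku about the ocean.",
            "max_tokens": 64,
            "temperature": 0.7,
            "stop": ["\n\n"],
        },
    )
    print(resp.json()["choices"][0]["text"])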
---
llms/mlx_lm/server.py | 306 ++++++++++++++++++++++++++++++------------
1 file changed, 217 insertions(+), 89 deletions(-)
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index da27b8d0..a679216c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -4,7 +4,7 @@ import time
import uuid
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
and how many tokens should be trimmed from the end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
- return StopCondition(stop_met=True, trim_length=0)
+ return StopCondition(stop_met=True, trim_length=1)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
return prompt.rstrip()
-def create_response(chat_id, requested_model, prompt, tokens, text):
+def create_chat_response(chat_id, requested_model, prompt, tokens, text):
response = {
"id": chat_id,
"object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
return response
-def create_chunk_response(chat_id, requested_model, next_chunk):
+def create_completion_response(completion_id, requested_model, prompt, tokens, text):
+ return {
+ "id": completion_id,
+ "object": "text_completion",
+ "created": int(time.time()),
+ "model": requested_model,
+ "system_fingerprint": f"fp_{uuid.uuid4()}",
+ "choices": [
+ {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
+ ],
+ "usage": {
+ "prompt_tokens": len(prompt),
+ "completion_tokens": len(tokens),
+ "total_tokens": len(prompt) + len(tokens),
+ },
+ }
+
+
+def create_chat_chunk_response(chat_id, requested_model, next_chunk):
response = {
"id": chat_id,
"object": "chat.completion.chunk",
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
return response
+def create_completion_chunk_response(completion_id, requested_model, next_chunk):
+ return {
+ "id": completion_id,
+ "object": "text_completion",
+ "created": int(time.time()),
+ "choices": [
+ {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
+ ],
+ "model": requested_model,
+ "system_fingerprint": f"fp_{uuid.uuid4()}",
+ }
+
+
class APIHandler(BaseHTTPRequestHandler):
def _set_headers(self, status_code=200):
self.send_response(status_code)
@@ -186,14 +217,128 @@ class APIHandler(BaseHTTPRequestHandler):
post_data = self.rfile.read(content_length)
self._set_headers(200)
- response = self.handle_post_request(post_data)
+ response = self.handle_chat_completions(post_data)
+
+ self.wfile.write(json.dumps(response).encode("utf-8"))
+ elif self.path == "/v1/completions":
+ content_length = int(self.headers["Content-Length"])
+ post_data = self.rfile.read(content_length)
+ self._set_headers(200)
+
+ response = self.handle_completions(post_data)
self.wfile.write(json.dumps(response).encode("utf-8"))
else:
self._set_headers(404)
self.wfile.write(b"Not Found")
- def handle_post_request(self, post_data):
+ def generate_response(
+ self,
+ prompt: mx.array,
+ response_id: str,
+ requested_model: str,
+ stop_id_sequences: List[np.ndarray],
+ eos_token_id: int,
+ max_tokens: int,
+ temperature: float,
+ top_p: float,
+ response_creator: Callable[[str, str, mx.array, List[int], str], dict],
+ ):
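+        # Non-streaming path: sample up to `max_tokens` tokens, stop early on
+        # EOS or a stop sequence, then build the full response dict with
+        # `response_creator`.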
+ tokens = []
+ for token, _ in zip(
+ generate(
+ prompt,
+ _model,
+ temperature,
+ top_p=top_p,
+ ),
+ range(max_tokens),
+ ):
+ tokens.append(token)
+ stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
+ if stop_condition.stop_met:
+ if stop_condition.trim_length:
+ tokens = tokens[: -stop_condition.trim_length]
+ break
+
+ text = _tokenizer.decode(tokens)
+ return response_creator(response_id, requested_model, prompt, tokens, text)
+
+    def handle_stream(
+ self,
+ prompt: mx.array,
+ response_id: str,
+ requested_model: str,
+ stop_id_sequences: List[np.ndarray],
+ eos_token_id: int,
+ max_tokens: int,
+ temperature: float,
+ top_p: float,
+ response_creator: Callable[[str, str, str], dict],
+ ):
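+        # Streaming path: emit chunks as server-sent events, holding back the
+        # last `max_stop_id_sequence_len` tokens so stop sequences can be
+        # trimmed before their text is written to the client.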
+ self.send_response(200)
+ self.send_header("Content-type", "text/event-stream")
+ self.send_header("Cache-Control", "no-cache")
+ self.end_headers()
+ max_stop_id_sequence_len = (
+ max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+ )
+ tokens = []
+ current_generated_text_index = 0
+ # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+ stop_sequence_buffer = []
+ REPLACEMENT_CHAR = "\ufffd"
+ for token, _ in zip(
+ generate(
+ prompt,
+ _model,
+ temperature,
+ top_p=top_p,
+ ),
+ range(max_tokens),
+ ):
+ tokens.append(token)
+ stop_sequence_buffer.append(token)
+ if len(stop_sequence_buffer) > max_stop_id_sequence_len:
+ if REPLACEMENT_CHAR in _tokenizer.decode(token):
+ continue
+ stop_condition = stopping_criteria(
+ tokens,
+ stop_id_sequences,
+ eos_token_id,
+ )
+ if stop_condition.stop_met:
+ if stop_condition.trim_length:
+ tokens = tokens[: -stop_condition.trim_length]
+ break
+                # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
+                generated_text = _tokenizer.decode(tokens)
+                next_chunk = generated_text[current_generated_text_index:]
+                current_generated_text_index = len(generated_text)
+
+                response = response_creator(response_id, requested_model, next_chunk)
+                try:
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+                    stop_sequence_buffer = []
+                except Exception as e:
+                    print(e)
+                    break
+        # Check whether there is any remaining text to send.
+        if stop_sequence_buffer:
+            generated_text = _tokenizer.decode(tokens)
+            next_chunk = generated_text[current_generated_text_index:]
+            response = response_creator(response_id, requested_model, next_chunk)
+            try:
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            except Exception as e:
+                print(e)
+
+        self.wfile.write(f"data: [DONE]\n\n".encode())
+        self.wfile.flush()
+
+ def handle_chat_completions(self, post_data):
body = json.loads(post_data.decode("utf-8"))
chat_id = f"chatcmpl-{uuid.uuid4()}"
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -223,91 +368,74 @@ class APIHandler(BaseHTTPRequestHandler):
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
if not stream:
- tokens = []
- for token, _ in zip(
- generate(
- prompt,
- _model,
- temperature,
- top_p=top_p,
- ),
- range(max_tokens),
- ):
- tokens.append(token)
- stop_condition = stopping_criteria(
- tokens, stop_id_sequences, eos_token_id
- )
- if stop_condition.stop_met:
- if stop_condition.trim_length:
- tokens = tokens[: -stop_condition.trim_length]
- break
-
- text = _tokenizer.decode(tokens)
- return create_response(chat_id, requested_model, prompt, tokens, text)
- else:
- self.send_response(200)
- self.send_header("Content-type", "text/event-stream")
- self.send_header("Cache-Control", "no-cache")
- self.end_headers()
- max_stop_id_sequence_len = (
- max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+ return self.generate_response(
+ prompt,
+ chat_id,
+ requested_model,
+ stop_id_sequences,
+ eos_token_id,
+ max_tokens,
+ temperature,
+ top_p,
+ create_chat_response,
+ )
+ else:
+            self.handle_stream(
+ prompt,
+ chat_id,
+ requested_model,
+ stop_id_sequences,
+ eos_token_id,
+ max_tokens,
+ temperature,
+ top_p,
+ create_chat_chunk_response,
)
- tokens = []
- current_generated_text_index = 0
- # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
- stop_sequence_buffer = []
- REPLACEMENT_CHAR = "\ufffd"
- for token, _ in zip(
- generate(
- prompt,
- _model,
- temperature,
- top_p=top_p,
- ),
- range(max_tokens),
- ):
- tokens.append(token)
- stop_sequence_buffer.append(token)
- if len(stop_sequence_buffer) > max_stop_id_sequence_len:
- if REPLACEMENT_CHAR in _tokenizer.decode(token):
- continue
- stop_condition = stopping_criteria(
- tokens,
- stop_id_sequences,
- eos_token_id,
- )
- if stop_condition.stop_met:
- if stop_condition.trim_length:
- tokens = tokens[: -stop_condition.trim_length]
- break
-                    # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
-                    generated_text = _tokenizer.decode(tokens)
-                    next_chunk = generated_text[current_generated_text_index:]
-                    current_generated_text_index = len(generated_text)
-                    response = create_chunk_response(
-                        chat_id, requested_model, next_chunk
-                    )
-                    try:
-                        self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                        self.wfile.flush()
-                        stop_sequence_buffer = []
-                    except Exception as e:
-                        print(e)
-                        break
-            # check is there any remaining text to send
-            if stop_sequence_buffer:
-                generated_text = _tokenizer.decode(tokens)
-                next_chunk = generated_text[current_generated_text_index:]
-                response = create_chunk_response(chat_id, requested_model, next_chunk)
-                try:
-                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                    self.wfile.flush()
-                except Exception as e:
-                    print(e)
-
- self.wfile.write(f"data: [DONE]\n\n".encode())
- self.wfile.flush()
+ def handle_completions(self, post_data):
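+        # /v1/completions: tokenize the raw prompt, then dispatch to the
+        # non-streaming or streaming generation path.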
+ body = json.loads(post_data.decode("utf-8"))
+ completion_id = f"cmpl-{uuid.uuid4()}"
+ prompt_text = body["prompt"]
+ prompt = _tokenizer.encode(prompt_text, return_tensors="np")
+ prompt = mx.array(prompt[0])
+ stop_words = body.get("stop", [])
+ stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+ stop_id_sequences = [
+ _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+ 0
+ ]
+ for stop_word in stop_words
+ ]
+ eos_token_id = _tokenizer.eos_token_id
+ max_tokens = body.get("max_tokens", 100)
+ stream = body.get("stream", False)
+ requested_model = body.get("model", "default_model")
+ temperature = body.get("temperature", 1.0)
+ top_p = body.get("top_p", 1.0)
+ if not stream:
+ return self.generate_response(
+ prompt,
+ completion_id,
+ requested_model,
+ stop_id_sequences,
+ eos_token_id,
+ max_tokens,
+ temperature,
+ top_p,
+ create_completion_response,
+ )
+ else:
+            self.handle_stream(
+ prompt,
+ completion_id,
+ requested_model,
+ stop_id_sequences,
+ eos_token_id,
+ max_tokens,
+ temperature,
+ top_p,
+ create_completion_chunk_response,
+ )
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):