From 88458c4e4024bd1bc6c8b410c8fe088c0cc33095 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 19 Feb 2024 09:01:28 +1100
Subject: [PATCH] feat(mlx-lm): add OpenAI-like API server (#429)
* feat(mlx-lm): add OpenAI-like API server
* chore: fix sse format
* chore: add top_p support
* chore: fix the load import
* chore: add workaround for missing space in stream decoding
* chore: fix typo
* chore: add error handling for streaming
* chore: using slicing instead of replace
* chore: set host, port via args and improve handle stream token logic
* chore: refactor stop sequence function
* chore: rename stopping_criteria
* fix: unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
* chore: fix the streaming unicode issue
* Update llms/mlx_lm/server.py
Co-authored-by: Awni Hannun
* refactor: move stopping_criteria out of generate func
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/server.py | 349 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 349 insertions(+)
create mode 100644 llms/mlx_lm/server.py
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
new file mode 100644
index 00000000..da27b8d0
--- /dev/null
+++ b/llms/mlx_lm/server.py
@@ -0,0 +1,349 @@
+import argparse
+import json
+import time
+import uuid
+from collections import namedtuple
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from transformers import PreTrainedTokenizer
+
+from .utils import load
+
+_model: Optional[nn.Module] = None
+_tokenizer: Optional[PreTrainedTokenizer] = None
+
+
+def load_model(model_path: str, adapter_file: Optional[str] = None):
+ global _model
+ global _tokenizer
+ _model, _tokenizer = load(model_path, adapter_file=adapter_file)
+
+
+StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])
+
+
+def stopping_criteria(
+ tokens: List[int],
+ stop_id_sequences: List[np.ndarray],
+ eos_token_id: int,
+) -> StopCondition:
+ """
+ Determines whether the token generation should stop based on predefined conditions.
+
+ Args:
+ tokens (List[int]): The current sequence of generated tokens.
+ stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs.
+ If the end of the `tokens` list matches any of these sequences, the generation should stop.
+ eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
+ the generation should stop.
+
+ Returns:
+ StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
+ and how many tokens should be trimmed from the end if it has (`trim_length`).
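+
+    Example (illustrative only; the token ids are arbitrary):
+        >>> stopping_criteria([1, 2, 7, 8], [np.array([7, 8])], eos_token_id=0)
+        StopCondition(stop_met=True, trim_length=2)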
+ """
+ if tokens and tokens[-1] == eos_token_id:
+ return StopCondition(stop_met=True, trim_length=0)
+
+ for stop_ids in stop_id_sequences:
+ if len(tokens) >= len(stop_ids):
+ if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
+ return StopCondition(stop_met=True, trim_length=len(stop_ids))
+
+ return StopCondition(stop_met=False, trim_length=0)
+
+
+def generate(
+ prompt: mx.array,
+ model: nn.Module,
+ temp: float = 0.0,
+ top_p: float = 1.0,
+):
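+    """Yield sampled token ids from `model`, one at a time, conditioned on `prompt`.
+
+    Greedy decoding when `temp` is 0; otherwise temperature sampling, with
+    nucleus (top-p) filtering when 0 < top_p < 1. The generator never stops on
+    its own, so callers are expected to bound it (e.g. by zipping with a range)
+    and apply their own stopping criteria.
+    """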
+ def sample(logits):
+ if temp == 0:
+ return mx.argmax(logits, axis=-1)
+ else:
+ if top_p > 0 and top_p < 1.0:
+ if (
+ logits.dtype == mx.bfloat16
+                ):  # workaround: the contiguous_scan_inclusive_sum_bfloat16_bfloat16 kernel cannot be loaded, so cast to float32
+ logits = logits.astype(mx.float32)
+ probs = mx.softmax(logits / temp, axis=-1)
+
+ sorted_probs = mx.sort(probs)[::-1]
+ sorted_indices = mx.argsort(probs)[::-1]
+ cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+
+ top_probs = mx.where(
+ cumulative_probs > 1 - top_p,
+ sorted_probs,
+ mx.zeros_like(sorted_probs),
+ )
+ sorted_tok = mx.random.categorical(mx.log(top_probs))
+ tok = sorted_indices.squeeze(0)[sorted_tok]
+ return tok
+ return mx.random.categorical(logits * (1 / temp))
+
+ y = prompt
+ cache = None
+
+ while True:
+ logits, cache = model(y[None], cache=cache)
+ logits = logits[:, -1, :]
+
+ y = sample(logits)
+ token = y.item()
+
+ yield token
+
+
+def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
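+    """Build a plain-text prompt from an OpenAI-style `messages` list.
+
+    Used as a fallback when the tokenizer provides no chat template. With the
+    default role mapping, [{"role": "user", "content": "Hello"}] becomes
+    "USER: Hello" followed by a newline and "ASSISTANT:".
+    """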
+ default_role_mapping = {
+ "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
+ "system": "ASSISTANT's RULE: ",
+ "user": "USER: ",
+ "assistant": "ASSISTANT: ",
+ "stop": "\n",
+ }
+ role_mapping = role_mapping if role_mapping is not None else default_role_mapping
+
+ prompt = ""
+ for line in messages:
+ role_prefix = role_mapping.get(line["role"], "")
+ stop = role_mapping.get("stop", "")
+ content = line.get("content", "")
+ prompt += f"{role_prefix}{content}{stop}"
+
+ prompt += role_mapping.get("assistant", "")
+ return prompt.rstrip()
+
+
+def create_response(chat_id, requested_model, prompt, tokens, text):
+ response = {
+ "id": chat_id,
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": requested_model,
+ "system_fingerprint": f"fp_{uuid.uuid4()}",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": text,
+ },
+ "logprobs": None,
+ "finish_reason": None,
+ }
+ ],
+ "usage": {
+ "prompt_tokens": len(prompt),
+ "completion_tokens": len(tokens),
+ "total_tokens": len(prompt) + len(tokens),
+ },
+ }
+
+ return response
+
+
+def create_chunk_response(chat_id, requested_model, next_chunk):
+ response = {
+ "id": chat_id,
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": requested_model,
+ "system_fingerprint": f"fp_{uuid.uuid4()}",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": next_chunk},
+ "logprobs": None,
+ "finish_reason": None,
+ }
+ ],
+ }
+ return response
+
+
+class APIHandler(BaseHTTPRequestHandler):
+ def _set_headers(self, status_code=200):
+ self.send_response(status_code)
+ self.send_header("Content-type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Access-Control-Allow-Methods", "*")
+ self.send_header("Access-Control-Allow-Headers", "*")
+ self.end_headers()
+
+ def do_OPTIONS(self):
+ self._set_headers(204)
+
+ def do_POST(self):
+ if self.path == "/v1/chat/completions":
+ content_length = int(self.headers["Content-Length"])
+ post_data = self.rfile.read(content_length)
+ self._set_headers(200)
+
+ response = self.handle_post_request(post_data)
+
+ self.wfile.write(json.dumps(response).encode("utf-8"))
+ else:
+ self._set_headers(404)
+ self.wfile.write(b"Not Found")
+
+ def handle_post_request(self, post_data):
+ body = json.loads(post_data.decode("utf-8"))
+ chat_id = f"chatcmpl-{uuid.uuid4()}"
+ if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
+ prompt = _tokenizer.apply_chat_template(
+ body["messages"],
+ tokenize=True,
+ add_generation_prompt=True,
+ return_tensors="np",
+ )
+ else:
+ prompt = convert_chat(body["messages"], body.get("role_mapping"))
+ prompt = _tokenizer.encode(prompt, return_tensors="np")
+
+ prompt = mx.array(prompt[0])
+ stop_words = body.get("stop", [])
+ stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+ stop_id_sequences = [
+ _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+ 0
+ ]
+ for stop_word in stop_words
+ ]
+ eos_token_id = _tokenizer.eos_token_id
+ max_tokens = body.get("max_tokens", 100)
+ stream = body.get("stream", False)
+ requested_model = body.get("model", "default_model")
+ temperature = body.get("temperature", 1.0)
+ top_p = body.get("top_p", 1.0)
+ if not stream:
+ tokens = []
+ for token, _ in zip(
+ generate(
+ prompt,
+ _model,
+ temperature,
+ top_p=top_p,
+ ),
+ range(max_tokens),
+ ):
+ tokens.append(token)
+ stop_condition = stopping_criteria(
+ tokens, stop_id_sequences, eos_token_id
+ )
+ if stop_condition.stop_met:
+ if stop_condition.trim_length:
+ tokens = tokens[: -stop_condition.trim_length]
+ break
+
+ text = _tokenizer.decode(tokens)
+ return create_response(chat_id, requested_model, prompt, tokens, text)
+ else:
+ self.send_response(200)
+ self.send_header("Content-type", "text/event-stream")
+ self.send_header("Cache-Control", "no-cache")
+ self.end_headers()
+ max_stop_id_sequence_len = (
+ max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+ )
+ tokens = []
+ current_generated_text_index = 0
+ # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+ stop_sequence_buffer = []
+ REPLACEMENT_CHAR = "\ufffd"
+ for token, _ in zip(
+ generate(
+ prompt,
+ _model,
+ temperature,
+ top_p=top_p,
+ ),
+ range(max_tokens),
+ ):
+ tokens.append(token)
+ stop_sequence_buffer.append(token)
+ if len(stop_sequence_buffer) > max_stop_id_sequence_len:
+ if REPLACEMENT_CHAR in _tokenizer.decode(token):
+ continue
+ stop_condition = stopping_criteria(
+ tokens,
+ stop_id_sequences,
+ eos_token_id,
+ )
+ if stop_condition.stop_met:
+ if stop_condition.trim_length:
+ tokens = tokens[: -stop_condition.trim_length]
+ break
+                    # Workaround: some tokenizers (e.g. the llama tokenizer) drop leading spaces when decoding token by token, so decode the full sequence and emit only the new suffix.
+ generated_text = _tokenizer.decode(tokens)
+ next_chunk = generated_text[current_generated_text_index:]
+ current_generated_text_index = len(generated_text)
+
+ response = create_chunk_response(
+ chat_id, requested_model, next_chunk
+ )
+ try:
+ self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+ self.wfile.flush()
+ stop_sequence_buffer = []
+ except Exception as e:
+ print(e)
+ break
+            # Check whether there is any remaining text to send.
+ if stop_sequence_buffer:
+ generated_text = _tokenizer.decode(tokens)
+ next_chunk = generated_text[current_generated_text_index:]
+ response = create_chunk_response(chat_id, requested_model, next_chunk)
+ try:
+ self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+ self.wfile.flush()
+ except Exception as e:
+ print(e)
+
+            self.wfile.write("data: [DONE]\n\n".encode())
+ self.wfile.flush()
+
+
+def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+ server_address = (host, port)
+ httpd = server_class(server_address, handler_class)
+ print(f"Starting httpd at {host} on port {port}...")
+ httpd.serve_forever()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="MLX HTTP Server.")
+ parser.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="The path to the MLX model weights, tokenizer, and config",
+ )
+ parser.add_argument(
+ "--adapter-file",
+ type=str,
+ help="Optional path for the trained adapter weights.",
+ )
+ parser.add_argument(
+ "--host",
+ type=str,
+ default="127.0.0.1",
+ help="Host for the HTTP server (default: 127.0.0.1)",
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8080,
+ help="Port for the HTTP server (default: 8080)",
+ )
+ args = parser.parse_args()
+
+ load_model(args.model, adapter_file=args.adapter_file)
+
+ run(args.host, args.port)
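
For reviewers, a minimal non-streaming client for the endpoint this patch adds could look like the sketch below (not part of the patch). It assumes the server is running on the default 127.0.0.1:8080 from the argument parser; the "model" field is only echoed back in the response. When "stream" is true, the server instead emits "data: {...}" SSE chunks terminated by "data: [DONE]".

    # Sketch only: assumes the server added by this patch is listening on
    # the default host/port (127.0.0.1:8080).
    import json
    from urllib import request

    payload = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 100,
        "temperature": 0.7,
        "stream": False,
    }
    req = request.Request(
        "http://127.0.0.1:8080/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        body = json.loads(resp.read().decode("utf-8"))
    print(body["choices"][0]["message"]["content"])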