From 88458c4e4024bd1bc6c8b410c8fe088c0cc33095 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 19 Feb 2024 09:01:28 +1100
Subject: [PATCH] feat(mlx-lm): add OpenAI-like API server (#429)

* feat(mlx-lm): add OpenAI-like API server

* chore: fix sse format

* chore: add top_p support

* chore: fix the load import

* chore: add workaround for missing space in stream decoding

* chore: fix typo

* chore: add error handling for streaming

* chore: using slicing instead of replace

* chore: set host, port via args and improve handle stream token logic

* chore: refactor stop sequence function

* chore: rename stopping_criteria

* fix: unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16

* chore: fix the streaming unicode issue

* Update llms/mlx_lm/server.py

Co-authored-by: Awni Hannun

* refactor: move stopping_criteria out of generate func

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/server.py | 349 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 349 insertions(+)
 create mode 100644 llms/mlx_lm/server.py

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
new file mode 100644
index 00000000..da27b8d0
--- /dev/null
+++ b/llms/mlx_lm/server.py
@@ -0,0 +1,349 @@
+import argparse
+import json
+import time
+import uuid
+from collections import namedtuple
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from transformers import PreTrainedTokenizer
+
+from .utils import load
+
+_model: Optional[nn.Module] = None
+_tokenizer: Optional[PreTrainedTokenizer] = None
+
+
+def load_model(model_path: str, adapter_file: Optional[str] = None):
+    global _model
+    global _tokenizer
+    _model, _tokenizer = load(model_path, adapter_file=adapter_file)
+
+
+StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])
+
+
+def stopping_criteria(
+    tokens: List[int],
+    stop_id_sequences: List[np.ndarray],
+    eos_token_id: int,
+) -> StopCondition:
+    """
+    Determines whether the token generation should stop based on predefined conditions.
+
+    Args:
+        tokens (List[int]): The current sequence of generated tokens.
+        stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs.
+            If the end of the `tokens` list matches any of these sequences, the generation should stop.
+        eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
+            the generation should stop.
+
+    Returns:
+        StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
+            and how many tokens should be trimmed from the end if it has (`trim_length`).
+    """
+    if tokens and tokens[-1] == eos_token_id:
+        return StopCondition(stop_met=True, trim_length=0)
+
+    for stop_ids in stop_id_sequences:
+        if len(tokens) >= len(stop_ids):
+            if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
+                return StopCondition(stop_met=True, trim_length=len(stop_ids))
+
+    return StopCondition(stop_met=False, trim_length=0)
+
+
+def generate(
+    prompt: mx.array,
+    model: nn.Module,
+    temp: float = 0.0,
+    top_p: float = 1.0,
+):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            if top_p > 0 and top_p < 1.0:
+                if (
+                    logits.dtype == mx.bfloat16
+                ):  # workaround: unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
+                    logits = logits.astype(mx.float32)
+                probs = mx.softmax(logits / temp, axis=-1)
+
+                sorted_probs = mx.sort(probs)[::-1]
+                sorted_indices = mx.argsort(probs)[::-1]
+                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+
+                top_probs = mx.where(
+                    cumulative_probs > 1 - top_p,
+                    sorted_probs,
+                    mx.zeros_like(sorted_probs),
+                )
+                sorted_tok = mx.random.categorical(mx.log(top_probs))
+                tok = sorted_indices.squeeze(0)[sorted_tok]
+                return tok
+            return mx.random.categorical(logits * (1 / temp))
+
+    y = prompt
+    cache = None
+
+    while True:
+        logits, cache = model(y[None], cache=cache)
+        logits = logits[:, -1, :]
+
+        y = sample(logits)
+        token = y.item()
+
+        yield token
+
+
+def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
+    default_role_mapping = {
+        "system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
+        "system": "ASSISTANT's RULE: ",
+        "user": "USER: ",
+        "assistant": "ASSISTANT: ",
+        "stop": "\n",
+    }
+    role_mapping = role_mapping if role_mapping is not None else default_role_mapping
+
+    prompt = ""
+    for line in messages:
+        role_prefix = role_mapping.get(line["role"], "")
+        stop = role_mapping.get("stop", "")
+        content = line.get("content", "")
+        prompt += f"{role_prefix}{content}{stop}"
+
+    prompt += role_mapping.get("assistant", "")
+    return prompt.rstrip()
+
+
+def create_response(chat_id, requested_model, prompt, tokens, text):
+    response = {
+        "id": chat_id,
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": text,
+                },
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+        "usage": {
+            "prompt_tokens": len(prompt),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(prompt) + len(tokens),
+        },
+    }
+
+    return response
+
+
+def create_chunk_response(chat_id, requested_model, next_chunk):
+    response = {
+        "id": chat_id,
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"role": "assistant", "content": next_chunk},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+    return response
+
+
+class APIHandler(BaseHTTPRequestHandler):
+    def _set_headers(self, status_code=200):
+        self.send_response(status_code)
+        self.send_header("Content-type", "application/json")
+        self.send_header("Access-Control-Allow-Origin", "*")
+        self.send_header("Access-Control-Allow-Methods", "*")
+        self.send_header("Access-Control-Allow-Headers", "*")
+        self.end_headers()
+
+    def do_OPTIONS(self):
+        self._set_headers(204)
+
+    def do_POST(self):
+        if self.path == "/v1/chat/completions":
+            content_length = int(self.headers["Content-Length"])
+            post_data = self.rfile.read(content_length)
+            self._set_headers(200)
+
+            response = self.handle_post_request(post_data)
+
+            self.wfile.write(json.dumps(response).encode("utf-8"))
+        else:
+            self._set_headers(404)
+            self.wfile.write(b"Not Found")
+
+    def handle_post_request(self, post_data):
+        body = json.loads(post_data.decode("utf-8"))
+        chat_id = f"chatcmpl-{uuid.uuid4()}"
+        if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
+            prompt = _tokenizer.apply_chat_template(
+                body["messages"],
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="np",
+            )
+        else:
+            prompt = convert_chat(body["messages"], body.get("role_mapping"))
+            prompt = _tokenizer.encode(prompt, return_tensors="np")
+
+        prompt = mx.array(prompt[0])
+        stop_words = body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+                0
+            ]
+            for stop_word in stop_words
+        ]
+        eos_token_id = _tokenizer.eos_token_id
+        max_tokens = body.get("max_tokens", 100)
+        stream = body.get("stream", False)
+        requested_model = body.get("model", "default_model")
+        temperature = body.get("temperature", 1.0)
+        top_p = body.get("top_p", 1.0)
+        if not stream:
+            tokens = []
+            for token, _ in zip(
+                generate(
+                    prompt,
+                    _model,
+                    temperature,
+                    top_p=top_p,
+                ),
+                range(max_tokens),
+            ):
+                tokens.append(token)
+                stop_condition = stopping_criteria(
+                    tokens, stop_id_sequences, eos_token_id
+                )
+                if stop_condition.stop_met:
+                    if stop_condition.trim_length:
+                        tokens = tokens[: -stop_condition.trim_length]
+                    break
+
+            text = _tokenizer.decode(tokens)
+            return create_response(chat_id, requested_model, prompt, tokens, text)
+        else:
+            self.send_response(200)
+            self.send_header("Content-type", "text/event-stream")
+            self.send_header("Cache-Control", "no-cache")
+            self.end_headers()
+            max_stop_id_sequence_len = (
+                max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+            )
+            tokens = []
+            current_generated_text_index = 0
+            # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+            stop_sequence_buffer = []
+            REPLACEMENT_CHAR = "\ufffd"
+            for token, _ in zip(
+                generate(
+                    prompt,
+                    _model,
+                    temperature,
+                    top_p=top_p,
+                ),
+                range(max_tokens),
+            ):
+                tokens.append(token)
+                stop_sequence_buffer.append(token)
+                if len(stop_sequence_buffer) > max_stop_id_sequence_len:
+                    if REPLACEMENT_CHAR in _tokenizer.decode(token):
+                        continue
+                    stop_condition = stopping_criteria(
+                        tokens,
+                        stop_id_sequences,
+                        eos_token_id,
+                    )
+                    if stop_condition.stop_met:
+                        if stop_condition.trim_length:
+                            tokens = tokens[: -stop_condition.trim_length]
+                        break
+                    # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
+                    generated_text = _tokenizer.decode(tokens)
+                    next_chunk = generated_text[current_generated_text_index:]
+                    current_generated_text_index = len(generated_text)
+
+                    response = create_chunk_response(
+                        chat_id, requested_model, next_chunk
+                    )
+                    try:
+                        self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                        self.wfile.flush()
+                        stop_sequence_buffer = []
+                    except Exception as e:
+                        print(e)
+                        break
+            # Check whether there is any remaining text to send.
+            if stop_sequence_buffer:
+                generated_text = _tokenizer.decode(tokens)
+                next_chunk = generated_text[current_generated_text_index:]
+                response = create_chunk_response(chat_id, requested_model, next_chunk)
+                try:
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+                except Exception as e:
+                    print(e)
+
+            self.wfile.write(f"data: [DONE]\n\n".encode())
+            self.wfile.flush()
+
+
+def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+    server_address = (host, port)
+    httpd = server_class(server_address, handler_class)
+    print(f"Starting httpd at {host} on port {port}...")
+    httpd.serve_forever()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="MLX HTTP Server.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="The path to the MLX model weights, tokenizer, and config",
+    )
+    parser.add_argument(
+        "--adapter-file",
+        type=str,
+        help="Optional path for the trained adapter weights.",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="127.0.0.1",
+        help="Host for the HTTP server (default: 127.0.0.1)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8080,
+        help="Port for the HTTP server (default: 8080)",
+    )
+    args = parser.parse_args()
+
+    load_model(args.model, adapter_file=args.adapter_file)
+
+    run(args.host, args.port)
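
Usage sketch: assuming the server added above has been started (for example with `python -m mlx_lm.server --model <path-to-model>`, if the package is installed) and is listening on the argparse defaults of 127.0.0.1:8080, a minimal non-streaming client needs only the Python standard library. The message content and sampling values below are placeholder assumptions; the endpoint path and JSON fields mirror handle_post_request.

import json
from urllib.request import Request, urlopen

# Placeholder request body; field names match what handle_post_request reads.
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "stream": False,
}

# POST the JSON body to the chat completions endpoint (assumed local server).
req = Request(
    "http://127.0.0.1:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    completion = json.loads(resp.read().decode("utf-8"))

# The reply follows the chat.completion shape built in create_response.
print(completion["choices"][0]["message"]["content"])

With "stream": true in the body, the server instead responds with text/event-stream: each `data:` line carries a chat.completion.chunk object and the stream ends with `data: [DONE]`.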