diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index da27b8d0..a679216c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -4,7 +4,7 @@ import time
 import uuid
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -46,7 +46,7 @@ def stopping_criteria(
             and how many tokens should be trimmed from the end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=0)
+        return StopCondition(stop_met=True, trim_length=1)
 
     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
@@ -121,7 +121,7 @@ def convert_chat(messages: any, role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
-def create_response(chat_id, requested_model, prompt, tokens, text):
+def create_chat_response(chat_id, requested_model, prompt, tokens, text):
     response = {
         "id": chat_id,
         "object": "chat.completion",
@@ -149,7 +149,25 @@ def create_response(chat_id, requested_model, prompt, tokens, text):
     return response
 
 
-def create_chunk_response(chat_id, requested_model, next_chunk):
+def create_completion_response(completion_id, requested_model, prompt, tokens, text):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+        "choices": [
+            {"text": text, "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {
+            "prompt_tokens": len(prompt),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(prompt) + len(tokens),
+        },
+    }
+
+
+def create_chat_chunk_response(chat_id, requested_model, next_chunk):
     response = {
         "id": chat_id,
         "object": "chat.completion.chunk",
@@ -168,6 +186,19 @@ def create_chunk_response(chat_id, requested_model, next_chunk):
     return response
 
 
+def create_completion_chunk_response(completion_id, requested_model, next_chunk):
+    return {
+        "id": completion_id,
+        "object": "text_completion",
+        "created": int(time.time()),
+        "choices": [
+            {"text": next_chunk, "index": 0, "logprobs": None, "finish_reason": None}
+        ],
+        "model": requested_model,
+        "system_fingerprint": f"fp_{uuid.uuid4()}",
+    }
+
+
 class APIHandler(BaseHTTPRequestHandler):
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
@@ -186,14 +217,128 @@ class APIHandler(BaseHTTPRequestHandler):
             post_data = self.rfile.read(content_length)
             self._set_headers(200)
 
-            response = self.handle_post_request(post_data)
+            response = self.handle_chat_completions(post_data)
+
+            self.wfile.write(json.dumps(response).encode("utf-8"))
+        elif self.path == "/v1/completions":
+            content_length = int(self.headers["Content-Length"])
+            post_data = self.rfile.read(content_length)
+            self._set_headers(200)
+
+            response = self.handle_completions(post_data)
 
             self.wfile.write(json.dumps(response).encode("utf-8"))
         else:
            self._set_headers(404)
            self.wfile.write(b"Not Found")
 
-    def handle_post_request(self, post_data):
+    def generate_response(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, mx.array, List[int], str], dict],
+    ):
+        tokens = []
+        for token, _ in zip(
+            generate(
+                prompt,
+                _model,
+                temperature,
+                top_p=top_p,
+            ),
+            range(max_tokens),
+        ):
+            tokens.append(token)
+            stop_condition = stopping_criteria(tokens, stop_id_sequences, eos_token_id)
+            if stop_condition.stop_met:
+                if stop_condition.trim_length:
+                    tokens = tokens[: -stop_condition.trim_length]
+                break
+
+        text = _tokenizer.decode(tokens)
+        return response_creator(response_id, requested_model, prompt, tokens, text)
+
+    def handle_stream(
+        self,
+        prompt: mx.array,
+        response_id: str,
+        requested_model: str,
+        stop_id_sequences: List[np.ndarray],
+        eos_token_id: int,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        response_creator: Callable[[str, str, str], dict],
+    ):
+        self.send_response(200)
+        self.send_header("Content-type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.end_headers()
+        max_stop_id_sequence_len = (
+            max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+        )
+        tokens = []
+        current_generated_text_index = 0
+        # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
+        stop_sequence_buffer = []
+        REPLACEMENT_CHAR = "\ufffd"
+        for token, _ in zip(
+            generate(
+                prompt,
+                _model,
+                temperature,
+                top_p=top_p,
+            ),
+            range(max_tokens),
+        ):
+            tokens.append(token)
+            stop_sequence_buffer.append(token)
+            if len(stop_sequence_buffer) > max_stop_id_sequence_len:
+                if REPLACEMENT_CHAR in _tokenizer.decode(token):
+                    continue
+                stop_condition = stopping_criteria(
+                    tokens,
+                    stop_id_sequences,
+                    eos_token_id,
+                )
+                if stop_condition.stop_met:
+                    if stop_condition.trim_length:
+                        tokens = tokens[: -stop_condition.trim_length]
+                    break
+                # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
+                generated_text = _tokenizer.decode(tokens)
+                next_chunk = generated_text[current_generated_text_index:]
+                current_generated_text_index = len(generated_text)
+
+                response = response_creator(response_id, requested_model, next_chunk)
+                try:
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+                    stop_sequence_buffer = []
+                except Exception as e:
+                    print(e)
+                    break
+        # check if there is any remaining text to send
+        if stop_sequence_buffer:
+            generated_text = _tokenizer.decode(tokens)
+            next_chunk = generated_text[current_generated_text_index:]
+            response = response_creator(response_id, requested_model, next_chunk)
+            try:
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            except Exception as e:
+                print(e)
+
+        self.wfile.write(f"data: [DONE]\n\n".encode())
+        self.wfile.flush()
+
+    def handle_chat_completions(self, post_data):
         body = json.loads(post_data.decode("utf-8"))
         chat_id = f"chatcmpl-{uuid.uuid4()}"
         if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -223,91 +368,74 @@ class APIHandler(BaseHTTPRequestHandler):
         temperature = body.get("temperature", 1.0)
         top_p = body.get("top_p", 1.0)
         if not stream:
-            tokens = []
-            for token, _ in zip(
-                generate(
-                    prompt,
-                    _model,
-                    temperature,
-                    top_p=top_p,
-                ),
-                range(max_tokens),
-            ):
-                tokens.append(token)
-                stop_condition = stopping_criteria(
-                    tokens, stop_id_sequences, eos_token_id
-                )
-                if stop_condition.stop_met:
-                    if stop_condition.trim_length:
-                        tokens = tokens[: -stop_condition.trim_length]
-                    break
-
-            text = _tokenizer.decode(tokens)
-            return create_response(chat_id, requested_model, prompt, tokens, text)
-        else:
-            self.send_response(200)
-            self.send_header("Content-type", "text/event-stream")
-            self.send_header("Cache-Control", "no-cache")
-            self.end_headers()
-            max_stop_id_sequence_len = (
-                max(len(seq) for seq in stop_id_sequences) if stop_id_sequences else 0
+            return self.generate_response(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_response,
+            )
+        else:
+            self.handle_stream(
+                prompt,
+                chat_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_chat_chunk_response,
             )
-            tokens = []
-            current_generated_text_index = 0
-            # Buffer to store the last `max_stop_id_sequence_len` tokens to check for stop conditions before writing to the stream.
-            stop_sequence_buffer = []
-            REPLACEMENT_CHAR = "\ufffd"
-            for token, _ in zip(
-                generate(
-                    prompt,
-                    _model,
-                    temperature,
-                    top_p=top_p,
-                ),
-                range(max_tokens),
-            ):
-                tokens.append(token)
-                stop_sequence_buffer.append(token)
-                if len(stop_sequence_buffer) > max_stop_id_sequence_len:
-                    if REPLACEMENT_CHAR in _tokenizer.decode(token):
-                        continue
-                    stop_condition = stopping_criteria(
-                        tokens,
-                        stop_id_sequences,
-                        eos_token_id,
-                    )
-                    if stop_condition.stop_met:
-                        if stop_condition.trim_length:
-                            tokens = tokens[: -stop_condition.trim_length]
-                        break
-                    # This is a workaround because the llama tokenizer emits spaces when decoding token by token.
-                    generated_text = _tokenizer.decode(tokens)
-                    next_chunk = generated_text[current_generated_text_index:]
-                    current_generated_text_index = len(generated_text)
-                    response = create_chunk_response(
-                        chat_id, requested_model, next_chunk
-                    )
-                    try:
-                        self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                        self.wfile.flush()
-                        stop_sequence_buffer = []
-                    except Exception as e:
-                        print(e)
-                        break
-            # check is there any remaining text to send
-            if stop_sequence_buffer:
-                generated_text = _tokenizer.decode(tokens)
-                next_chunk = generated_text[current_generated_text_index:]
-                response = create_chunk_response(chat_id, requested_model, next_chunk)
-                try:
-                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                    self.wfile.flush()
-                except Exception as e:
-                    print(e)
-
-            self.wfile.write(f"data: [DONE]\n\n".encode())
-            self.wfile.flush()
 
+    def handle_completions(self, post_data):
+        body = json.loads(post_data.decode("utf-8"))
+        completion_id = f"cmpl-{uuid.uuid4()}"
+        prompt_text = body["prompt"]
+        prompt = _tokenizer.encode(prompt_text, return_tensors="np")
+        prompt = mx.array(prompt[0])
+        stop_words = body.get("stop", [])
+        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
+        stop_id_sequences = [
+            _tokenizer.encode(stop_word, return_tensors="np", add_special_tokens=False)[
+                0
+            ]
+            for stop_word in stop_words
+        ]
+        eos_token_id = _tokenizer.eos_token_id
+        max_tokens = body.get("max_tokens", 100)
+        stream = body.get("stream", False)
+        requested_model = body.get("model", "default_model")
+        temperature = body.get("temperature", 1.0)
+        top_p = body.get("top_p", 1.0)
+        if not stream:
+            return self.generate_response(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_response,
+            )
+        else:
+            self.handle_stream(
+                prompt,
+                completion_id,
+                requested_model,
+                stop_id_sequences,
+                eos_token_id,
+                max_tokens,
+                temperature,
+                top_p,
+                create_completion_chunk_response,
+            )
 
 
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
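
For reference, a minimal client sketch (not part of the patch) that exercises the new /v1/completions endpoint served by handle_completions above. It assumes the server has been started locally with something like "python -m mlx_lm.server --model <model_path>" and is listening on 127.0.0.1:8080; adjust the host and port to match your setup. Only fields that handle_completions actually reads are sent (prompt, max_tokens, temperature, top_p, stop, stream).

# Illustrative client for the /v1/completions endpoint; host, port, and model path are assumptions.
import json
from urllib.request import Request, urlopen

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 50,   # handle_completions defaults to 100 if omitted
    "temperature": 0.7,
    "top_p": 0.9,
    "stop": "\n\n",     # a string or a list of strings, as accepted by handle_completions
    "stream": False,    # True would return SSE "data: ..." chunks instead of a single JSON body
}

req = Request(
    "http://127.0.0.1:8080/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urlopen(req) as resp:
    completion = json.loads(resp.read().decode("utf-8"))

# The non-streaming response is the dict built by create_completion_response:
# an OpenAI-style "text_completion" object with one choice and a usage block.
print(completion["choices"][0]["text"])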