diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 482ee00c..e765bb5e 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -10,15 +10,10 @@ from typing import List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
-from transformers import PreTrainedTokenizer
+from .tokenizer_utils import TokenizerWrapper
 
 from .utils import generate_step, load
 
-MODEL: nn.Module
-TOKENIZER: PreTrainedTokenizer
-
-SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
-
 
 class StopCondition(NamedTuple):
     stop_met: bool
@@ -77,10 +72,12 @@
 
 
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
         """
         Create static request specific metadata
         """
+        self.model = model
+        self.tokenizer = tokenizer
         self.created = int(time.time())
         super().__init__(*args, **kwargs)
 
@@ -136,7 +133,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
-            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            self.tokenizer.encode(stop_word, add_special_tokens=False)
             for stop_word in stop_words
         ]
 
@@ -183,7 +180,7 @@
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": SYSTEM_FINGERPRINT,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -237,12 +234,15 @@
             stop_id_sequences (List[List[int]]):
                 A list of stop words passed to the stopping_criteria function
         """
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
         finish_reason = "length"
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -250,18 +250,25 @@
             ),
             range(self.max_tokens),
         ):
-            token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_condition = stopping_criteria(
-                tokens, stop_id_sequences, TOKENIZER.eos_token_id
+                tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
             if stop_condition.stop_met:
                 finish_reason = "stop"
                 if stop_condition.trim_length:
-                    tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-        text = TOKENIZER.decode(tokens)
+        detokenizer.finalize()
+        text = (
+            detokenizer.text
+            if stop_sequence_suffix is None
+            else detokenizer.text[: -len(stop_sequence_suffix)]
+        )
         response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
 
         response_json = json.dumps(response).encode()
@@ -289,18 +296,19 @@ class APIHandler(BaseHTTPRequestHandler):
         # No additional headers are needed, call end_headers
         self.end_headers()
 
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
-        current_generated_text_index = 0
 
         max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
         # Buffer to store the last `max_stop_id_sequence_len` tokens
         # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
-
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -308,7 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
-            token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_sequence_buffer.append(token)
 
@@ -316,26 +324,20 @@ class APIHandler(BaseHTTPRequestHandler):
             if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                 continue
 
-            # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
-            # should be combined into a single unicode character
-            if "\ufffd" in TOKENIZER.decode(token):
-                continue
-
             stop_condition = stopping_criteria(
                 tokens,
                 stop_id_sequences,
-                TOKENIZER.eos_token_id,
+                self.tokenizer.eos_token_id,
             )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
-                    tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-            # Workaround for llama tokenizer emitting spaces when decoding token by token.
-            generated_text = TOKENIZER.decode(tokens)
-            new_text = generated_text[current_generated_text_index:]
-            current_generated_text_index = len(generated_text)
-
+            detokenizer.finalize()
+            new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
@@ -343,8 +345,12 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = TOKENIZER.decode(tokens)
-            next_chunk = generated_text[current_generated_text_index:]
+            detokenizer.finalize()
+            next_chunk = (
+                detokenizer.last_segment
+                if stop_sequence_suffix is None
+                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
+            )
             response = self.generate_response(next_chunk, "length")
 
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
@@ -369,15 +375,18 @@ class APIHandler(BaseHTTPRequestHandler):
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
 
-        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
-            prompt = TOKENIZER.apply_chat_template(
+        if (
+            hasattr(self.tokenizer, "apply_chat_template")
+            and self.tokenizer.chat_template
+        ):
+            prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
            )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = TOKENIZER.encode(prompt)
+            prompt = self.tokenizer.encode(prompt)
 
         return mx.array(prompt)
 
@@ -394,13 +403,24 @@ class APIHandler(BaseHTTPRequestHandler):
 
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-        prompt = TOKENIZER.encode(prompt_text)
+
+        prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
 
-def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+def run(
+    host: str,
+    port: int,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    server_class=HTTPServer,
+    handler_class=APIHandler,
+):
     server_address = (host, port)
-    httpd = server_class(server_address, handler_class)
+    httpd = server_class(
+        server_address,
+        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+    )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
         "it only implements basic security checks."
@@ -444,11 +464,10 @@ def main():
 
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    MODEL, TOKENIZER = load(
+    model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
-
-    run(args.host, args.port)
+    run(args.host, args.port, model, tokenizer)
 
 
 if __name__ == "__main__":
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d5a03270..be273e67 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -360,7 +360,7 @@ def load(
     tokenizer_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
-) -> Tuple[nn.Module, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
 
@@ -374,7 +374,7 @@
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
 
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
new file mode 100644
index 00000000..998ad1c7
--- /dev/null
+++ b/llms/tests/test_server.py
@@ -0,0 +1,76 @@
+import http
+import threading
+import unittest
+
+import requests
+from mlx_lm.server import APIHandler
+from mlx_lm.utils import load
+
+
+class TestServer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+
+        cls.server_address = ("localhost", 0)
+        cls.httpd = http.server.HTTPServer(
+            cls.server_address,
+            lambda *args, **kwargs: APIHandler(
+                cls.model, cls.tokenizer, *args, **kwargs
+            ),
+        )
+        cls.port = cls.httpd.server_port
+        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
+        cls.server_thread.daemon = True
+        cls.server_thread.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.httpd.shutdown()
+        cls.httpd.server_close()
+        cls.server_thread.join()
+
+    def test_handle_completions(self):
+        url = f"http://localhost:{self.port}/v1/completions"
+
+        post_data = {
+            "model": "default_model",
+            "prompt": "Once upon a time",
+            "max_tokens": 10,
+            "temperature": 0.5,
+            "top_p": 0.9,
+            "repetition_penalty": 1.1,
+            "repetition_context_size": 20,
+            "stop": "stop sequence",
+        }
+
+        response = requests.post(url, json=post_data)
+
+        response_body = response.text
+
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+    def test_handle_chat_completions(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Hello!"},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+
+if __name__ == "__main__":
+    unittest.main()
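
Usage sketch (not part of the patch): with the module-level MODEL/TOKENIZER globals removed, a caller now passes the loaded model and tokenizer into run() explicitly. The snippet below assumes the same mlx-community/Qwen1.5-0.5B-Chat-4bit checkpoint used by the new tests; the host and port values are placeholders.

    from mlx_lm.server import run
    from mlx_lm.utils import load

    # load() now returns (nn.Module, TokenizerWrapper); run() forwards both
    # to each APIHandler instance through the lambda handler factory.
    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
    run("127.0.0.1", 8080, model, tokenizer)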