From f5f189e48a426724ecd8ad8e161444922a2585db Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 19 Apr 2024 07:26:18 +1000
Subject: [PATCH] fix(mlx-lm): broken server.py (#690)
* fix server.py: pass the model and tokenizer into APIHandler instead of relying on module-level globals
* fix variable referenced before assignment (initialize stop_sequence_suffix before the generation loop)
* add test
* clean up
---
llms/mlx_lm/server.py | 101 ++++++++++++++++++++++----------------
llms/mlx_lm/utils.py | 4 +-
llms/tests/test_server.py | 76 ++++++++++++++++++++++++++++
3 files changed, 138 insertions(+), 43 deletions(-)
create mode 100644 llms/tests/test_server.py
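The server.py change below replaces the module-level MODEL/TOKENIZER globals (which were only assigned inside main(), so the handler could hit them before assignment) with per-request injection through a handler factory. A minimal, self-contained sketch of that pattern outside mlx-lm; EchoHandler and shared_state are illustrative names only, not part of the patch:

import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class EchoHandler(BaseHTTPRequestHandler):
    def __init__(self, shared_state, *args, **kwargs):
        # Dependencies are injected here instead of being read from globals.
        # BaseHTTPRequestHandler services the request inside __init__, so
        # attributes must be assigned before calling super().__init__.
        self.shared_state = shared_state
        super().__init__(*args, **kwargs)

    def do_GET(self):
        body = json.dumps(self.shared_state).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    state = {"model": "stub"}
    # HTTPServer accepts any callable that builds a handler, so a lambda can
    # close over the injected state, just as run() does in the patch below.
    httpd = HTTPServer(
        ("localhost", 8080),
        lambda *args, **kwargs: EchoHandler(state, *args, **kwargs),
    )
    httpd.serve_forever()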
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 482ee00c..e765bb5e 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -10,15 +10,10 @@ from typing import List, Literal, NamedTuple, Optional, Union
import mlx.core as mx
import mlx.nn as nn
-from transformers import PreTrainedTokenizer
+from .tokenizer_utils import TokenizerWrapper
from .utils import generate_step, load
-MODEL: nn.Module
-TOKENIZER: PreTrainedTokenizer
-
-SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
-
class StopCondition(NamedTuple):
stop_met: bool
@@ -77,10 +72,12 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
class APIHandler(BaseHTTPRequestHandler):
- def __init__(self, *args, **kwargs):
+ def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
"""
Create static request specific metadata
"""
+ self.model = model
+ self.tokenizer = tokenizer
self.created = int(time.time())
super().__init__(*args, **kwargs)
@@ -136,7 +133,7 @@ class APIHandler(BaseHTTPRequestHandler):
stop_words = self.body.get("stop", [])
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
- TOKENIZER.encode(stop_word, add_special_tokens=False)
+ self.tokenizer.encode(stop_word, add_special_tokens=False)
for stop_word in stop_words
]
@@ -183,7 +180,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response
response = {
"id": self.request_id,
- "system_fingerprint": SYSTEM_FINGERPRINT,
+ "system_fingerprint": f"fp_{uuid.uuid4()}",
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@@ -237,12 +234,15 @@ class APIHandler(BaseHTTPRequestHandler):
stop_id_sequences (List[List[int]]):
A list of stop words passed to the stopping_criteria function
"""
+ detokenizer = self.tokenizer.detokenizer
+ detokenizer.reset()
tokens = []
finish_reason = "length"
+ stop_sequence_suffix = None
for (token, _), _ in zip(
generate_step(
prompt=prompt,
- model=MODEL,
+ model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
@@ -250,18 +250,25 @@ class APIHandler(BaseHTTPRequestHandler):
),
range(self.max_tokens),
):
- token = token.item()
+ detokenizer.add_token(token)
tokens.append(token)
stop_condition = stopping_criteria(
- tokens, stop_id_sequences, TOKENIZER.eos_token_id
+ tokens, stop_id_sequences, self.tokenizer.eos_token_id
)
if stop_condition.stop_met:
finish_reason = "stop"
if stop_condition.trim_length:
- tokens = tokens[: -stop_condition.trim_length]
+ stop_sequence_suffix = self.tokenizer.decode(
+ tokens[-stop_condition.trim_length :]
+ )
break
- text = TOKENIZER.decode(tokens)
+ detokenizer.finalize()
+ text = (
+ detokenizer.text
+ if stop_sequence_suffix is None
+ else detokenizer.text[: -len(stop_sequence_suffix)]
+ )
response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
response_json = json.dumps(response).encode()
@@ -289,18 +296,19 @@ class APIHandler(BaseHTTPRequestHandler):
# No additional headers are needed, call end_headers
self.end_headers()
+ detokenizer = self.tokenizer.detokenizer
+ detokenizer.reset()
tokens = []
- current_generated_text_index = 0
max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
# Buffer to store the last `max_stop_id_sequence_len` tokens
# to check for stop conditions before writing to the stream.
stop_sequence_buffer = []
-
+ stop_sequence_suffix = None
for (token, _), _ in zip(
generate_step(
prompt=prompt,
- model=MODEL,
+ model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
@@ -308,7 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
),
range(self.max_tokens),
):
- token = token.item()
+ detokenizer.add_token(token)
tokens.append(token)
stop_sequence_buffer.append(token)
@@ -316,26 +324,20 @@ class APIHandler(BaseHTTPRequestHandler):
if len(stop_sequence_buffer) < max_stop_id_sequence_len:
continue
- # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
- # should be combined into a single unicode character
- if "\ufffd" in TOKENIZER.decode(token):
- continue
-
stop_condition = stopping_criteria(
tokens,
stop_id_sequences,
- TOKENIZER.eos_token_id,
+ self.tokenizer.eos_token_id,
)
if stop_condition.stop_met:
if stop_condition.trim_length:
- tokens = tokens[: -stop_condition.trim_length]
+ stop_sequence_suffix = self.tokenizer.decode(
+ tokens[-stop_condition.trim_length :]
+ )
break
- # Workaround for llama tokenizer emitting spaces when decoding token by token.
- generated_text = TOKENIZER.decode(tokens)
- new_text = generated_text[current_generated_text_index:]
- current_generated_text_index = len(generated_text)
-
+ detokenizer.finalize()
+ new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
@@ -343,8 +345,12 @@ class APIHandler(BaseHTTPRequestHandler):
# check is there any remaining text to send
if stop_sequence_buffer:
- generated_text = TOKENIZER.decode(tokens)
- next_chunk = generated_text[current_generated_text_index:]
+ detokenizer.finalize()
+ next_chunk = (
+ detokenizer.last_segment
+ if stop_sequence_suffix is None
+ else detokenizer.last_segment[: -len(stop_sequence_suffix)]
+ )
response = self.generate_response(next_chunk, "length")
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
@@ -369,15 +375,18 @@ class APIHandler(BaseHTTPRequestHandler):
"chat.completions.chunk" if self.stream else "chat.completions"
)
- if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
- prompt = TOKENIZER.apply_chat_template(
+ if (
+ hasattr(self.tokenizer, "apply_chat_template")
+ and self.tokenizer.chat_template
+ ):
+ prompt = self.tokenizer.apply_chat_template(
body["messages"],
tokenize=True,
add_generation_prompt=True,
)
else:
prompt = convert_chat(body["messages"], body.get("role_mapping"))
- prompt = TOKENIZER.encode(prompt)
+ prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
@@ -394,13 +403,24 @@ class APIHandler(BaseHTTPRequestHandler):
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
- prompt = TOKENIZER.encode(prompt_text)
+
+ prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
-def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+def run(
+ host: str,
+ port: int,
+ model: nn.Module,
+ tokenizer: TokenizerWrapper,
+ server_class=HTTPServer,
+ handler_class=APIHandler,
+):
server_address = (host, port)
- httpd = server_class(server_address, handler_class)
+ httpd = server_class(
+ server_address,
+ lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+ )
warnings.warn(
"mlx_lm.server is not recommended for production as "
"it only implements basic security checks."
@@ -444,11 +464,10 @@ def main():
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
- MODEL, TOKENIZER = load(
+ model, tokenizer = load(
args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
)
-
- run(args.host, args.port)
+ run(args.host, args.port, model, tokenizer)
if __name__ == "__main__":
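With the refactor, run() takes the model and tokenizer explicitly. A short sketch of starting the server programmatically, mirroring main() above; the model path is the one the new test downloads, and the host/port are illustrative:

from mlx_lm.server import run
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
# Serves /v1/completions and /v1/chat/completions until interrupted.
run("localhost", 8080, model, tokenizer)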
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d5a03270..be273e67 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -360,7 +360,7 @@ def load(
tokenizer_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
-) -> Tuple[nn.Module, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
@@ -374,7 +374,7 @@ def load(
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
Returns:
- Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+ Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
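The annotation change reflects that load() now returns a TokenizerWrapper rather than a bare PreTrainedTokenizer; besides encode/decode it exposes the streaming detokenizer the server now relies on. A minimal sketch of that lifecycle, mirroring the non-streaming handler above (the model path and prompt are illustrative only):

import mlx.core as mx
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = mx.array(tokenizer.encode("Once upon a time"))

detokenizer = tokenizer.detokenizer
detokenizer.reset()                      # start a fresh generation
for (token, _), _ in zip(generate_step(prompt=prompt, model=model), range(10)):
    detokenizer.add_token(token)         # feed raw token ids as they arrive
detokenizer.finalize()                   # flush any pending partial characters
print(detokenizer.text)                  # decoded text for the whole generation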
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
new file mode 100644
index 00000000..998ad1c7
--- /dev/null
+++ b/llms/tests/test_server.py
@@ -0,0 +1,76 @@
+import http.server
+import threading
+import unittest
+
+import requests
+from mlx_lm.server import APIHandler
+from mlx_lm.utils import load
+
+
+class TestServer(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+ cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+
+ cls.server_address = ("localhost", 0)
+ cls.httpd = http.server.HTTPServer(
+ cls.server_address,
+ lambda *args, **kwargs: APIHandler(
+ cls.model, cls.tokenizer, *args, **kwargs
+ ),
+ )
+ cls.port = cls.httpd.server_port
+ cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
+ cls.server_thread.daemon = True
+ cls.server_thread.start()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.httpd.shutdown()
+ cls.httpd.server_close()
+ cls.server_thread.join()
+
+ def test_handle_completions(self):
+ url = f"http://localhost:{self.port}/v1/completions"
+
+ post_data = {
+ "model": "default_model",
+ "prompt": "Once upon a time",
+ "max_tokens": 10,
+ "temperature": 0.5,
+ "top_p": 0.9,
+ "repetition_penalty": 1.1,
+ "repetition_context_size": 20,
+ "stop": "stop sequence",
+ }
+
+ response = requests.post(url, json=post_data)
+
+ response_body = response.text
+
+ self.assertIn("id", response_body)
+ self.assertIn("choices", response_body)
+
+ def test_handle_chat_completions(self):
+ url = f"http://localhost:{self.port}/v1/chat/completions"
+ chat_post_data = {
+ "model": "chat_model",
+ "max_tokens": 10,
+ "temperature": 0.7,
+ "top_p": 0.85,
+ "repetition_penalty": 1.2,
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello!"},
+ ],
+ }
+ response = requests.post(url, json=chat_post_data)
+ response_body = response.text
+ self.assertIn("id", response_body)
+ self.assertIn("choices", response_body)
+
+
+if __name__ == "__main__":
+ unittest.main()
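The tests above only exercise the non-streaming path. A hedged sketch of hitting the streaming endpoint of a running server; it assumes the request field is "stream" (OpenAI-style, matching self.stream in the handler) and that the server was started as in main(), e.g. python -m mlx_lm.server --model mlx-community/Qwen1.5-0.5B-Chat-4bit --port 8080:

import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "stream": True,          # assumption: OpenAI-style streaming flag
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
)
for line in response.iter_lines():
    # The handler writes server-sent events of the form b"data: {...}".
    if line.startswith(b"data: "):
        print(line[len(b"data: "):].decode())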