fix(mlx-lm): broken server.py (#690)

* fix server.py

* fix var referenced before assignment

* add test

* clean up
Anchen 2024-04-19 07:26:18 +10:00 committed by GitHub
parent 35206806ac
commit f5f189e48a
3 changed files with 138 additions and 43 deletions

llms/mlx_lm/server.py

@@ -10,15 +10,10 @@ from typing import List, Literal, NamedTuple, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from transformers import PreTrainedTokenizer
+from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
-MODEL: nn.Module
-TOKENIZER: PreTrainedTokenizer
-SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
 class StopCondition(NamedTuple):
     stop_met: bool
@@ -77,10 +72,12 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
         """
         Create static request specific metadata
         """
+        self.model = model
+        self.tokenizer = tokenizer
         self.created = int(time.time())
         super().__init__(*args, **kwargs)
@@ -136,7 +133,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
-            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            self.tokenizer.encode(stop_word, add_special_tokens=False)
             for stop_word in stop_words
         ]
@@ -183,7 +180,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": SYSTEM_FINGERPRINT,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -237,12 +234,15 @@
             stop_id_sequences (List[List[int]]):
                 A list of stop words passed to the stopping_criteria function
         """
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
         finish_reason = "length"
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -250,18 +250,25 @@
             ),
             range(self.max_tokens),
         ):
-            token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_condition = stopping_criteria(
-                tokens, stop_id_sequences, TOKENIZER.eos_token_id
+                tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
             if stop_condition.stop_met:
                 finish_reason = "stop"
                 if stop_condition.trim_length:
-                    tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
-        text = TOKENIZER.decode(tokens)
+        detokenizer.finalize()
+        text = (
+            detokenizer.text
+            if stop_sequence_suffix is None
+            else detokenizer.text[: -len(stop_sequence_suffix)]
+        )
         response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
         response_json = json.dumps(response).encode()
@@ -289,18 +296,19 @@
         # No additional headers are needed, call end_headers
         self.end_headers()
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
-        current_generated_text_index = 0
         max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
         # Buffer to store the last `max_stop_id_sequence_len` tokens
         # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -308,7 +316,7 @@
             ),
             range(self.max_tokens),
         ):
-            token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_sequence_buffer.append(token)
@@ -316,26 +324,20 @@
             if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                 continue
-            # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
-            # should be combined into a single unicode character
-            if "\ufffd" in TOKENIZER.decode(token):
-                continue
             stop_condition = stopping_criteria(
                 tokens,
                 stop_id_sequences,
-                TOKENIZER.eos_token_id,
+                self.tokenizer.eos_token_id,
             )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
-                    tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
-            # Workaround for llama tokenizer emitting spaces when decoding token by token.
-            generated_text = TOKENIZER.decode(tokens)
-            new_text = generated_text[current_generated_text_index:]
-            current_generated_text_index = len(generated_text)
+            detokenizer.finalize()
+            new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
@@ -343,8 +345,12 @@
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = TOKENIZER.decode(tokens)
-            next_chunk = generated_text[current_generated_text_index:]
+            detokenizer.finalize()
+            next_chunk = (
+                detokenizer.last_segment
+                if stop_sequence_suffix is None
+                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
+            )
             response = self.generate_response(next_chunk, "length")
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
@@ -369,15 +375,18 @@
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
-        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
-            prompt = TOKENIZER.apply_chat_template(
+        if (
+            hasattr(self.tokenizer, "apply_chat_template")
+            and self.tokenizer.chat_template
+        ):
+            prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
             )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = TOKENIZER.encode(prompt)
+            prompt = self.tokenizer.encode(prompt)
         return mx.array(prompt)
@@ -394,13 +403,24 @@
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-        prompt = TOKENIZER.encode(prompt_text)
+        prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
-def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+def run(
+    host: str,
+    port: int,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    server_class=HTTPServer,
+    handler_class=APIHandler,
+):
     server_address = (host, port)
-    httpd = server_class(server_address, handler_class)
+    httpd = server_class(
+        server_address,
+        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+    )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
         "it only implements basic security checks."
@@ -444,11 +464,10 @@ def main():
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
-    MODEL, TOKENIZER = load(
+    model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
-    run(args.host, args.port)
+    run(args.host, args.port, model, tokenizer)
 if __name__ == "__main__":
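
Note on the shape of the fix above: the handler no longer reads module-level MODEL/TOKENIZER globals (presumably the source of the "var referenced before assignment" mentioned in the commit message); the model and tokenizer are injected into each APIHandler through the factory lambda inside run(). A minimal sketch of starting the server programmatically with the new signature, assuming run is importable from mlx_lm.server the same way APIHandler is in the test below; the model path is only an example taken from that test:

from mlx_lm.server import run
from mlx_lm.utils import load

# load() now returns (nn.Module, TokenizerWrapper)
model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# run() wraps HTTPServer with a lambda that forwards model and tokenizer
# into every APIHandler instance it constructs
run("127.0.0.1", 8080, model, tokenizer)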

llms/mlx_lm/utils.py

@@ -360,7 +360,7 @@ def load(
     tokenizer_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
-) -> Tuple[nn.Module, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -374,7 +374,7 @@ def load(
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
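
Because load() now returns a TokenizerWrapper, the server can use its incremental detokenizer instead of re-decoding the full token list after every generated token (the old llama-tokenizer workaround removed in the streaming handler above). A rough sketch of that pattern outside the server, mirroring the non-streaming handler and assuming generate_step's optional sampling arguments can be left at their defaults; the model path is again just an example:

import mlx.core as mx

from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = mx.array(tokenizer.encode("Once upon a time"))

detokenizer = tokenizer.detokenizer
detokenizer.reset()
for (token, _), _ in zip(generate_step(prompt=prompt, model=model), range(32)):
    detokenizer.add_token(token)  # accumulate tokens one at a time
detokenizer.finalize()            # flush any partially decoded characters
print(detokenizer.text)           # the full decoded completion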

llms/tests/test_server.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+import http
+import threading
+import unittest
+
+import requests
+
+from mlx_lm.server import APIHandler
+from mlx_lm.utils import load
+
+
+class TestServer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+
+        cls.server_address = ("localhost", 0)
+        cls.httpd = http.server.HTTPServer(
+            cls.server_address,
+            lambda *args, **kwargs: APIHandler(
+                cls.model, cls.tokenizer, *args, **kwargs
+            ),
+        )
+        cls.port = cls.httpd.server_port
+        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
+        cls.server_thread.daemon = True
+        cls.server_thread.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.httpd.shutdown()
+        cls.httpd.server_close()
+        cls.server_thread.join()
+
+    def test_handle_completions(self):
+        url = f"http://localhost:{self.port}/v1/completions"
+
+        post_data = {
+            "model": "default_model",
+            "prompt": "Once upon a time",
+            "max_tokens": 10,
+            "temperature": 0.5,
+            "top_p": 0.9,
+            "repetition_penalty": 1.1,
+            "repetition_context_size": 20,
+            "stop": "stop sequence",
+        }
+
+        response = requests.post(url, json=post_data)
+
+        response_body = response.text
+
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+    def test_handle_chat_completions(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Hello!"},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+
+if __name__ == "__main__":
+    unittest.main()
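
The test above drives the handler in-process; the same endpoints can also be exercised against a server started from the command line. A hedged example: the flag names below are assumed from the args.model / args.host / args.port fields read in main() (the argument parser itself is not shown in this diff), and the response fields come from the assertions in the test:

# In another shell, something like:
#   python -m mlx_lm.server --model mlx-community/Qwen1.5-0.5B-Chat-4bit --port 8080
# (flag names assumed; only args.model / args.host / args.port appear in this diff)
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "max_tokens": 32,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
body = resp.json()
print(body["id"], body["choices"])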