fix(mlx-lm): broken server.py (#690)
* fix server.py
* fix var referenced before assignment
* add test
* clean up
parent 35206806ac
commit f5f189e48a
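The underlying problem: server.py kept the loaded model and tokenizer in module-level `MODEL`/`TOKENIZER` globals that are only assigned inside `main()`, which is what the "var referenced before assignment" item in the commit message refers to. The fix injects both objects into each `APIHandler` instance through a handler factory handed to `HTTPServer`. Below is a minimal sketch of that pattern with hypothetical names (`EchoHandler`, `greeting`); it is not code from this repo.

```python
# Sketch of the dependency-injection pattern this commit adopts.
# EchoHandler and greeting are hypothetical stand-ins for APIHandler and
# the (model, tokenizer) pair.
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer


class EchoHandler(BaseHTTPRequestHandler):
    def __init__(self, greeting: str, *args, **kwargs):
        # Attributes must be set before super().__init__, because
        # BaseHTTPRequestHandler handles the request inside its __init__.
        self.greeting = greeting
        super().__init__(*args, **kwargs)

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(self.greeting.encode())


if __name__ == "__main__":
    # partial plays the same role as the lambda in run(): it binds the extra
    # constructor arguments before HTTPServer instantiates a handler per request.
    httpd = HTTPServer(("localhost", 8080), partial(EchoHandler, "hello"))
    httpd.serve_forever()
```

In the diff below a `lambda` plays the role of `partial`, and the injected dependencies are the loaded model and a `TokenizerWrapper`.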
llms/mlx_lm/server.py

@@ -10,15 +10,10 @@ from typing import List, Literal, NamedTuple, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
 from transformers import PreTrainedTokenizer
 
+from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
 
-MODEL: nn.Module
-TOKENIZER: PreTrainedTokenizer
-
-SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
-
 
 class StopCondition(NamedTuple):
     stop_met: bool
@@ -77,10 +72,12 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
 
 
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
         """
         Create static request specific metadata
         """
+        self.model = model
+        self.tokenizer = tokenizer
         self.created = int(time.time())
         super().__init__(*args, **kwargs)
 
@@ -136,7 +133,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
-            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            self.tokenizer.encode(stop_word, add_special_tokens=False)
             for stop_word in stop_words
         ]
 
@@ -183,7 +180,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": SYSTEM_FINGERPRINT,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -237,12 +234,15 @@ class APIHandler(BaseHTTPRequestHandler):
             stop_id_sequences (List[List[int]]):
                 A list of stop words passed to the stopping_criteria function
         """
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
         finish_reason = "length"
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -250,18 +250,25 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
             token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_condition = stopping_criteria(
-                tokens, stop_id_sequences, TOKENIZER.eos_token_id
+                tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
             if stop_condition.stop_met:
                 finish_reason = "stop"
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-        text = TOKENIZER.decode(tokens)
+        detokenizer.finalize()
+        text = (
+            detokenizer.text
+            if stop_sequence_suffix is None
+            else detokenizer.text[: -len(stop_sequence_suffix)]
+        )
         response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
 
         response_json = json.dumps(response).encode()
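For context, the non-streaming completion path above now feeds each generated token to mlx-lm's incremental detokenizer and reads the final text once, instead of re-decoding the whole token list. A minimal sketch of that flow follows; the checkpoint name is borrowed from the new test file and is only illustrative.

```python
# Sketch of the detokenizer flow used in the completion handler above.
# Loading a real checkpoint here is only for illustration.
from mlx_lm.utils import load

_, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

detokenizer = tokenizer.detokenizer
detokenizer.reset()
for token in tokenizer.encode("Once upon a time"):
    detokenizer.add_token(token)  # feed generated token ids one at a time
detokenizer.finalize()            # flush any partially decoded bytes
print(detokenizer.text)           # full decoded text, no per-step re-decode
```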
@@ -289,18 +296,19 @@ class APIHandler(BaseHTTPRequestHandler):
         # No additional headers are needed, call end_headers
         self.end_headers()
 
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
-        current_generated_text_index = 0
 
         max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
         # Buffer to store the last `max_stop_id_sequence_len` tokens
         # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
 
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -308,7 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
             token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_sequence_buffer.append(token)
 
@@ -316,26 +324,20 @@ class APIHandler(BaseHTTPRequestHandler):
             if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                 continue
 
-            # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
-            # should be combined into a single unicode character
-            if "\ufffd" in TOKENIZER.decode(token):
-                continue
-
             stop_condition = stopping_criteria(
                 tokens,
                 stop_id_sequences,
-                TOKENIZER.eos_token_id,
+                self.tokenizer.eos_token_id,
             )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-            # Workaround for llama tokenizer emitting spaces when decoding token by token.
-            generated_text = TOKENIZER.decode(tokens)
-            new_text = generated_text[current_generated_text_index:]
-            current_generated_text_index = len(generated_text)
-
+            detokenizer.finalize()
+            new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
@@ -343,8 +345,12 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = TOKENIZER.decode(tokens)
-            next_chunk = generated_text[current_generated_text_index:]
+            detokenizer.finalize()
+            next_chunk = (
+                detokenizer.last_segment
+                if stop_sequence_suffix is None
+                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
+            )
             response = self.generate_response(next_chunk, "length")
 
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
@@ -369,15 +375,18 @@ class APIHandler(BaseHTTPRequestHandler):
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
 
-        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
-            prompt = TOKENIZER.apply_chat_template(
+        if (
+            hasattr(self.tokenizer, "apply_chat_template")
+            and self.tokenizer.chat_template
+        ):
+            prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
             )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = TOKENIZER.encode(prompt)
+            prompt = self.tokenizer.encode(prompt)
 
         return mx.array(prompt)
 
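The chat prompt construction above is unchanged in behavior: if the tokenizer ships a chat template it is applied directly to the request messages, otherwise the messages are flattened and encoded. A hedged sketch of the template branch (the checkpoint and messages are illustrative, and the fallback join is a simplified stand-in for `convert_chat`):

```python
# Illustrative sketch of the chat-template branch in handle_chat_completions.
from mlx_lm.utils import load

_, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
    # tokenize=True returns token ids, ready to wrap in mx.array.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True
    )
else:
    # server.py uses convert_chat() here; this join is only a simplified stand-in.
    prompt = tokenizer.encode("\n".join(m["content"] for m in messages))

print(f"{len(prompt)} prompt tokens")
```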
@@ -394,13 +403,24 @@ class APIHandler(BaseHTTPRequestHandler):
 
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-        prompt = TOKENIZER.encode(prompt_text)
-
+        prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
 
-def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+def run(
+    host: str,
+    port: int,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    server_class=HTTPServer,
+    handler_class=APIHandler,
+):
     server_address = (host, port)
-    httpd = server_class(server_address, handler_class)
+    httpd = server_class(
+        server_address,
+        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+    )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
         "it only implements basic security checks."
@@ -444,11 +464,10 @@ def main():
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
 
-    MODEL, TOKENIZER = load(
+    model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
-    run(args.host, args.port)
+    run(args.host, args.port, model, tokenizer)
 
 
 if __name__ == "__main__":
llms/mlx_lm/utils.py

@@ -360,7 +360,7 @@ def load(
     tokenizer_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
-) -> Tuple[nn.Module, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
 
@@ -374,7 +374,7 @@ def load(
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
 
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
llms/tests/test_server.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+import http
+import threading
+import unittest
+
+import requests
+from mlx_lm.server import APIHandler
+from mlx_lm.utils import load
+
+
+class TestServer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+
+        cls.server_address = ("localhost", 0)
+        cls.httpd = http.server.HTTPServer(
+            cls.server_address,
+            lambda *args, **kwargs: APIHandler(
+                cls.model, cls.tokenizer, *args, **kwargs
+            ),
+        )
+        cls.port = cls.httpd.server_port
+        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
+        cls.server_thread.daemon = True
+        cls.server_thread.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.httpd.shutdown()
+        cls.httpd.server_close()
+        cls.server_thread.join()
+
+    def test_handle_completions(self):
+        url = f"http://localhost:{self.port}/v1/completions"
+
+        post_data = {
+            "model": "default_model",
+            "prompt": "Once upon a time",
+            "max_tokens": 10,
+            "temperature": 0.5,
+            "top_p": 0.9,
+            "repetition_penalty": 1.1,
+            "repetition_context_size": 20,
+            "stop": "stop sequence",
+        }
+
+        response = requests.post(url, json=post_data)
+
+        response_body = response.text
+
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+    def test_handle_chat_completions(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Hello!"},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+
+if __name__ == "__main__":
+    unittest.main()
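Beyond the unit test, the fixed entry point can be exercised end to end by starting the server from the command line and posting to the same endpoints. The snippet below is a hedged usage sketch; the model path, host, and port are assumptions to adapt to your setup.

```python
# Hedged usage sketch: query a locally running mlx_lm.server instance,
# e.g. started with `python -m mlx_lm.server --model <path-or-hf-repo>`.
import requests

BASE_URL = "http://localhost:8080"  # adjust to the --host/--port you passed

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "default_model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 32,
        "temperature": 0.7,
    },
)
print(resp.json()["choices"][0])
```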