Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25)
fix(mlx-lm): broken server.py (#690)
* fix server.py
* fix var referenced before assignment
* add test
* clean up
Parent: 35206806ac
Commit: f5f189e48a
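The "var referenced before assignment" item most likely refers to the module-level MODEL / TOKENIZER globals the old server relied on; the diff below removes them in favor of attributes on the request handler. A minimal sketch, not from the repo and with illustrative names, of why a bare annotation is fragile:

# Annotating a module-level name does not bind a value to it, so any code path
# that reads MODEL before main() assigns it fails at runtime.
MODEL: object  # annotation only; the name is still undefined

def use_model():
    return MODEL  # NameError: name 'MODEL' is not defined, until main() runs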
mlx_lm/server.py

@@ -10,15 +10,10 @@ from typing import List, Literal, NamedTuple, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from transformers import PreTrainedTokenizer
 
+from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
 
-MODEL: nn.Module
-TOKENIZER: PreTrainedTokenizer
-
-SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
-
 
 class StopCondition(NamedTuple):
     stop_met: bool
@@ -77,10 +72,12 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
 
 
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
         """
         Create static request specific metadata
         """
+        self.model = model
+        self.tokenizer = tokenizer
         self.created = int(time.time())
         super().__init__(*args, **kwargs)
 
@@ -136,7 +133,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
-            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            self.tokenizer.encode(stop_word, add_special_tokens=False)
            for stop_word in stop_words
         ]
 
@@ -183,7 +180,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": SYSTEM_FINGERPRINT,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -237,12 +234,15 @@ class APIHandler(BaseHTTPRequestHandler):
             stop_id_sequences (List[List[int]]):
                 A list of stop words passed to the stopping_criteria function
         """
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
         finish_reason = "length"
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -250,18 +250,25 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
-            token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_condition = stopping_criteria(
-                tokens, stop_id_sequences, TOKENIZER.eos_token_id
+                tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
             if stop_condition.stop_met:
                 finish_reason = "stop"
                 if stop_condition.trim_length:
-                    tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-        text = TOKENIZER.decode(tokens)
+        detokenizer.finalize()
+        text = (
+            detokenizer.text
+            if stop_sequence_suffix is None
+            else detokenizer.text[: -len(stop_sequence_suffix)]
+        )
         response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
 
         response_json = json.dumps(response).encode()
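The completion path now decodes incrementally through the tokenizer's detokenizer rather than calling decode on the full token list at the end. A hedged sketch of that flow, reusing only the method names visible in the diff (reset, add_token, finalize, text) and borrowing the small 4-bit model from the new test purely for illustration:

from mlx_lm.utils import load

# Illustrative model choice; any model loadable by mlx_lm should behave the same.
_, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

detokenizer = tokenizer.detokenizer
detokenizer.reset()
for token_id in tokenizer.encode("Hello, world!"):
    detokenizer.add_token(token_id)  # feed ids one at a time, as generate_step would

detokenizer.finalize()               # flush any partially decoded bytes
print(detokenizer.text)              # full decoded text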
@@ -289,18 +296,19 @@ class APIHandler(BaseHTTPRequestHandler):
         # No additional headers are needed, call end_headers
         self.end_headers()
 
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
-        current_generated_text_index = 0
 
         max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
         # Buffer to store the last `max_stop_id_sequence_len` tokens
         # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -308,7 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
-            token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_sequence_buffer.append(token)
 
@@ -316,26 +324,20 @@ class APIHandler(BaseHTTPRequestHandler):
             if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                 continue
 
-            # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
-            # should be combined into a single unicode character
-            if "\ufffd" in TOKENIZER.decode(token):
-                continue
-
             stop_condition = stopping_criteria(
                 tokens,
                 stop_id_sequences,
-                TOKENIZER.eos_token_id,
+                self.tokenizer.eos_token_id,
             )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
-                    tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-            # Workaround for llama tokenizer emitting spaces when decoding token by token.
-            generated_text = TOKENIZER.decode(tokens)
-            new_text = generated_text[current_generated_text_index:]
-            current_generated_text_index = len(generated_text)
-
+            detokenizer.finalize()
+            new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
@@ -343,8 +345,12 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = TOKENIZER.decode(tokens)
-            next_chunk = generated_text[current_generated_text_index:]
+            detokenizer.finalize()
+            next_chunk = (
+                detokenizer.last_segment
+                if stop_sequence_suffix is None
+                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
+            )
             response = self.generate_response(next_chunk, "length")
 
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
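When a stop sequence fires, the new code no longer truncates the token list before decoding; it decodes the trimmed tokens once into stop_sequence_suffix and slices that suffix off the detokenized text. A minimal sketch with purely illustrative values:

# Illustrative values only: drop the decoded stop suffix from the text
# instead of dropping tokens before decoding.
text = "Once upon a time STOP"
stop_sequence_suffix = " STOP"  # what the trimmed stop tokens decoded to

final_text = (
    text
    if stop_sequence_suffix is None
    else text[: -len(stop_sequence_suffix)]
)
print(final_text)  # -> "Once upon a time"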
@@ -369,15 +375,18 @@ class APIHandler(BaseHTTPRequestHandler):
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
 
-        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
-            prompt = TOKENIZER.apply_chat_template(
+        if (
+            hasattr(self.tokenizer, "apply_chat_template")
+            and self.tokenizer.chat_template
+        ):
+            prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
             )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = TOKENIZER.encode(prompt)
+            prompt = self.tokenizer.encode(prompt)
 
         return mx.array(prompt)
 
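The same chat-template branch can be tried standalone. A hedged sketch: it assumes the wrapper forwards apply_chat_template and chat_template to the underlying Hugging Face tokenizer (the handler above calls them on self.tokenizer), and the model name is only the one used by the new test:

from mlx_lm.utils import load

_, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

messages = [{"role": "user", "content": "Hello!"}]
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
    # tokenize=True returns token ids directly, matching the handler's usage.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True
    )
    print(prompt[:10])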
@@ -394,13 +403,24 @@ class APIHandler(BaseHTTPRequestHandler):
 
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-        prompt = TOKENIZER.encode(prompt_text)
+        prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
 
-def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+def run(
+    host: str,
+    port: int,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    server_class=HTTPServer,
+    handler_class=APIHandler,
+):
     server_address = (host, port)
-    httpd = server_class(server_address, handler_class)
+    httpd = server_class(
+        server_address,
+        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+    )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
         "it only implements basic security checks."
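HTTPServer instantiates its handler class itself, once per connection, so the model and tokenizer cannot be passed as ordinary constructor arguments; the lambda above binds them in advance. A minimal, self-contained sketch of that pattern, with names (EchoHandler, the greeting string, the port) that are illustrative and not from the repo:

from http.server import BaseHTTPRequestHandler, HTTPServer


class EchoHandler(BaseHTTPRequestHandler):
    def __init__(self, greeting, *args, **kwargs):
        # Set custom attributes before super().__init__: BaseHTTPRequestHandler
        # services the request inside its constructor.
        self.greeting = greeting
        super().__init__(*args, **kwargs)

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(self.greeting.encode())


if __name__ == "__main__":
    httpd = HTTPServer(
        ("localhost", 8080),
        # Factory: bind the extra argument, let HTTPServer supply the rest.
        lambda *args, **kwargs: EchoHandler("hello", *args, **kwargs),
    )
    httpd.serve_forever()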
@@ -444,11 +464,10 @@ def main():
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
 
-    MODEL, TOKENIZER = load(
+    model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
-
-    run(args.host, args.port)
+    run(args.host, args.port, model, tokenizer)
 
 
 if __name__ == "__main__":
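With the new signature, the server can also be started programmatically. A hedged sketch; the host, port, and model name are illustrative:

from mlx_lm.server import run
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
run("127.0.0.1", 8080, model, tokenizer)  # blocks and serves the /v1/... endpoints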
mlx_lm/utils.py

@@ -360,7 +360,7 @@ def load(
     tokenizer_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
-) -> Tuple[nn.Module, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
 
@@ -374,7 +374,7 @@ def load(
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
 
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
llms/tests/test_server.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+import http
+import threading
+import unittest
+
+import requests
+from mlx_lm.server import APIHandler
+from mlx_lm.utils import load
+
+
+class TestServer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+
+        cls.server_address = ("localhost", 0)
+        cls.httpd = http.server.HTTPServer(
+            cls.server_address,
+            lambda *args, **kwargs: APIHandler(
+                cls.model, cls.tokenizer, *args, **kwargs
+            ),
+        )
+        cls.port = cls.httpd.server_port
+        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
+        cls.server_thread.daemon = True
+        cls.server_thread.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.httpd.shutdown()
+        cls.httpd.server_close()
+        cls.server_thread.join()
+
+    def test_handle_completions(self):
+        url = f"http://localhost:{self.port}/v1/completions"
+
+        post_data = {
+            "model": "default_model",
+            "prompt": "Once upon a time",
+            "max_tokens": 10,
+            "temperature": 0.5,
+            "top_p": 0.9,
+            "repetition_penalty": 1.1,
+            "repetition_context_size": 20,
+            "stop": "stop sequence",
+        }
+
+        response = requests.post(url, json=post_data)
+
+        response_body = response.text
+
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+    def test_handle_chat_completions(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Hello!"},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
+
+if __name__ == "__main__":
+    unittest.main()
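The new tests only cover the non-streaming paths; the streaming branch reworked above can be exercised manually. A hedged sketch: it assumes a server already listening on localhost:8080 and an OpenAI-style "stream" flag in the request body (the "data: ..." server-sent-event framing matches what the handler writes in the diff):

import json

import requests

with requests.post(
    "http://localhost:8080/v1/completions",
    json={
        "model": "default_model",
        "prompt": "Once upon a time",
        "max_tokens": 10,
        "stream": True,  # assumption: field name follows the OpenAI API
    },
    stream=True,
) as response:
    for line in response.iter_lines():
        if not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload.strip() == b"[DONE]":  # some servers end the stream this way
            break
        print(json.loads(payload))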