Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-10-24 06:28:07 +08:00
fix(mlx-lm): broken server.py (#690)
* fix server.py
* fix var referenced before assignment
* add test
* clean up
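Background for the fix: the old server kept the model and tokenizer in annotated module globals (MODEL, TOKENIZER), and the commit message's "var referenced before assignment" points at the usual pitfall with that pattern: a bare annotation binds no value, and a later assignment inside a function does not populate the global unless it is declared `global`. Below is a minimal stand-alone sketch of that pitfall, with hypothetical names and a plain object standing in for the MLX model; the diff that follows removes the globals entirely and hands the loaded model and tokenizer to APIHandler through its constructor instead.

# Minimal sketch of the failure mode, with hypothetical names; a plain object
# stands in for the MLX model so the snippet runs anywhere.
MODEL: object  # annotation only -- no value is bound at module level


def handle_request():
    # Fails with NameError because nothing ever assigned the module-level MODEL.
    return MODEL


def main():
    MODEL = object()  # without `global MODEL`, this binds a function-local name


main()
try:
    handle_request()
except NameError as exc:
    print(exc)  # name 'MODEL' is not defined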
llms/mlx_lm/server.py
@@ -10,15 +10,10 @@ from typing import List, Literal, NamedTuple, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 from transformers import PreTrainedTokenizer
 
+from .tokenizer_utils import TokenizerWrapper
 from .utils import generate_step, load
 
-MODEL: nn.Module
-TOKENIZER: PreTrainedTokenizer
-
-SYSTEM_FINGERPRINT: str = f"fp_{uuid.uuid4()}"
-
 
 class StopCondition(NamedTuple):
     stop_met: bool
@@ -77,10 +72,12 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
 
 
 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
         """
        Create static request specific metadata
         """
+        self.model = model
+        self.tokenizer = tokenizer
         self.created = int(time.time())
         super().__init__(*args, **kwargs)
 
@@ -136,7 +133,7 @@ class APIHandler(BaseHTTPRequestHandler):
         stop_words = self.body.get("stop", [])
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
-            TOKENIZER.encode(stop_word, add_special_tokens=False)
+            self.tokenizer.encode(stop_word, add_special_tokens=False)
             for stop_word in stop_words
         ]
 
@@ -183,7 +180,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": SYSTEM_FINGERPRINT,
+            "system_fingerprint": f"fp_{uuid.uuid4()}",
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -237,12 +234,15 @@ class APIHandler(BaseHTTPRequestHandler):
             stop_id_sequences (List[List[int]]):
                 A list of stop words passed to the stopping_criteria function
         """
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
         finish_reason = "length"
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -250,18 +250,25 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
             token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_condition = stopping_criteria(
-                tokens, stop_id_sequences, TOKENIZER.eos_token_id
+                tokens, stop_id_sequences, self.tokenizer.eos_token_id
             )
             if stop_condition.stop_met:
                 finish_reason = "stop"
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-        text = TOKENIZER.decode(tokens)
+        detokenizer.finalize()
+        text = (
+            detokenizer.text
+            if stop_sequence_suffix is None
+            else detokenizer.text[: -len(stop_sequence_suffix)]
+        )
         response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
 
         response_json = json.dumps(response).encode()
@@ -289,18 +296,19 @@ class APIHandler(BaseHTTPRequestHandler):
         # No additional headers are needed, call end_headers
         self.end_headers()
 
+        detokenizer = self.tokenizer.detokenizer
+        detokenizer.reset()
         tokens = []
-        current_generated_text_index = 0
 
         max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
         # Buffer to store the last `max_stop_id_sequence_len` tokens
         # to check for stop conditions before writing to the stream.
         stop_sequence_buffer = []
 
+        stop_sequence_suffix = None
         for (token, _), _ in zip(
             generate_step(
                 prompt=prompt,
-                model=MODEL,
+                model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -308,7 +316,7 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
             token = token.item()
+            detokenizer.add_token(token)
             tokens.append(token)
             stop_sequence_buffer.append(token)
 
@@ -316,26 +324,20 @@ class APIHandler(BaseHTTPRequestHandler):
             if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                 continue
 
-            # "\ufffd" is used to indicate to the tokenizer, that subsequent characters
-            # should be combined into a single unicode character
-            if "\ufffd" in TOKENIZER.decode(token):
-                continue
-
             stop_condition = stopping_criteria(
                 tokens,
                 stop_id_sequences,
-                TOKENIZER.eos_token_id,
+                self.tokenizer.eos_token_id,
             )
             if stop_condition.stop_met:
                 if stop_condition.trim_length:
                     tokens = tokens[: -stop_condition.trim_length]
+                    stop_sequence_suffix = self.tokenizer.decode(
+                        tokens[-stop_condition.trim_length :]
+                    )
                 break
 
-            # Workaround for llama tokenizer emitting spaces when decoding token by token.
-            generated_text = TOKENIZER.decode(tokens)
-            new_text = generated_text[current_generated_text_index:]
-            current_generated_text_index = len(generated_text)
-
+            detokenizer.finalize()
+            new_text = detokenizer.last_segment
             response = self.generate_response(new_text, None)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
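The "\ufffd" check removed above existed because decoding one token at a time can split a multi-byte character: the replacement character only goes away once all of its bytes have arrived. The new code drops the check and relies on the streaming detokenizer instead. A plain-Python illustration of the effect, independent of any particular tokenizer:

# Plain-Python illustration: a multi-byte character decoded from an incomplete
# byte prefix yields U+FFFD until the remaining bytes are available.
emoji = "🙂".encode("utf-8")                        # 4 bytes
print(emoji[:2].decode("utf-8", errors="replace"))  # '\ufffd' placeholder(s)
print(emoji.decode("utf-8"))                        # '🙂' once all bytes arrive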
@@ -343,8 +345,12 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # check is there any remaining text to send
         if stop_sequence_buffer:
-            generated_text = TOKENIZER.decode(tokens)
-            next_chunk = generated_text[current_generated_text_index:]
+            detokenizer.finalize()
+            next_chunk = (
+                detokenizer.last_segment
+                if stop_sequence_suffix is None
+                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
+            )
             response = self.generate_response(next_chunk, "length")
 
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
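Unchanged by this commit but visible in the context above: the stream holds back the last `max_stop_id_sequence_len` tokens so a stop sequence can be detected, and its tokens trimmed, before their text is written to the client. A self-contained toy version of that suffix check follows; it is not mlx-lm's actual `stopping_criteria`, just the same idea with plain lists.

# Toy version of suffix-based stop detection (not mlx-lm's stopping_criteria):
# generation stops once the tail of the token list equals a stop sequence, and
# the matched tokens are trimmed so they are never emitted to the client.
from typing import List


def matched_stop_length(tokens: List[int], stop_id_sequences: List[List[int]]) -> int:
    """Length of a stop sequence matching the end of `tokens`, or 0."""
    for seq in stop_id_sequences:
        if seq and tokens[-len(seq):] == seq:
            return len(seq)
    return 0


stop_id_sequences = [[13, 7], [42]]
tokens: List[int] = []
for tok in [5, 13, 7, 9]:
    tokens.append(tok)
    trim = matched_stop_length(tokens, stop_id_sequences)
    if trim:
        tokens = tokens[:-trim]  # mirrors `tokens[: -stop_condition.trim_length]`
        break
print(tokens)  # [5]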
@@ -369,15 +375,18 @@ class APIHandler(BaseHTTPRequestHandler):
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
 
-        if hasattr(TOKENIZER, "apply_chat_template") and TOKENIZER.chat_template:
-            prompt = TOKENIZER.apply_chat_template(
+        if (
+            hasattr(self.tokenizer, "apply_chat_template")
+            and self.tokenizer.chat_template
+        ):
+            prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 tokenize=True,
                 add_generation_prompt=True,
             )
         else:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
-            prompt = TOKENIZER.encode(prompt)
+            prompt = self.tokenizer.encode(prompt)
 
         return mx.array(prompt)
 
@@ -394,13 +403,24 @@ class APIHandler(BaseHTTPRequestHandler):
 
         assert "prompt" in self.body, "Request did not contain a prompt"
         prompt_text = self.body["prompt"]
-        prompt = TOKENIZER.encode(prompt_text)
-
+        prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
 
-def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
+def run(
+    host: str,
+    port: int,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    server_class=HTTPServer,
+    handler_class=APIHandler,
+):
     server_address = (host, port)
-    httpd = server_class(server_address, handler_class)
+    httpd = server_class(
+        server_address,
+        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
+    )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
         "it only implements basic security checks."
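The lambda passed to `server_class` above is the core of the fix: `HTTPServer` constructs a new handler object for every incoming request, so per-server state cannot be handed over as ordinary constructor arguments; wrapping the handler class in a factory threads the model and tokenizer through to each `APIHandler` instance. The same wiring can be sketched with `functools.partial` and a hypothetical handler that has nothing to do with mlx-lm:

# Equivalent pattern to the lambda above, sketched with functools.partial and a
# hypothetical handler; HTTPServer calls the factory once per request, so the
# extra argument reaches every handler instance.
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer


class EchoHandler(BaseHTTPRequestHandler):
    def __init__(self, greeting, *args, **kwargs):
        # BaseHTTPRequestHandler.__init__ processes the request, so extra
        # attributes must be bound before calling super().__init__ -- the same
        # ordering the patched APIHandler uses for self.model/self.tokenizer.
        self.greeting = greeting
        super().__init__(*args, **kwargs)

    def do_GET(self):
        body = self.greeting.encode()
        self.send_response(200)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    httpd = HTTPServer(("127.0.0.1", 8080), partial(EchoHandler, "hello"))
    httpd.serve_forever()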
@@ -444,11 +464,10 @@ def main():
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
 
-    MODEL, TOKENIZER = load(
+    model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
-    run(args.host, args.port)
+    run(args.host, args.port, model, tokenizer)
 
 
 if __name__ == "__main__":
llms/mlx_lm/utils.py
@@ -360,7 +360,7 @@ def load(
     tokenizer_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
-) -> Tuple[nn.Module, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
 
@@ -374,7 +374,7 @@ def load(
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
 
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
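With the return annotation updated, callers of `load()` get a `TokenizerWrapper` back. The attributes server.py relies on (`encode`, `decode`, `eos_token_id`, `chat_template`, `apply_chat_template`) still resolve on the wrapper, and `tokenizer.detokenizer` exposes the streaming detokenizer used in the diff above. A hedged usage sketch, assuming an mlx_lm install; the model path is a placeholder, not something this commit pins:

# Hedged usage sketch: the model path below is a placeholder; swap in any local
# path or Hugging Face repo you actually have available.
from mlx_lm.utils import load

model, tokenizer = load("path/or/hf-repo-of-your-choice")

# The wrapper keeps the familiar tokenizer surface used by server.py ...
ids = tokenizer.encode("hello", add_special_tokens=False)
print(ids, tokenizer.decode(ids), tokenizer.eos_token_id)

# ... and adds the streaming detokenizer the server now relies on.
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for tok in ids:
    detokenizer.add_token(tok)
detokenizer.finalize()
print(detokenizer.text)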