use logging in mlx server (#705)

This commit is contained in:
Aaron Ng 2024-04-22 07:50:06 -07:00 committed by GitHub
parent f20e68fcc0
commit 8d5cf5b0c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
import argparse import argparse
import json import json
import logging
import time import time
import uuid import uuid
import warnings import warnings
@ -116,6 +117,8 @@ class APIHandler(BaseHTTPRequestHandler):
content_length = int(self.headers["Content-Length"]) content_length = int(self.headers["Content-Length"])
raw_body = self.rfile.read(content_length) raw_body = self.rfile.read(content_length)
self.body = json.loads(raw_body.decode()) self.body = json.loads(raw_body.decode())
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}")
assert isinstance( assert isinstance(
self.body, dict self.body, dict
), f"Request should be dict, but got {type(self.body)}" ), f"Request should be dict, but got {type(self.body)}"
@ -240,6 +243,7 @@ class APIHandler(BaseHTTPRequestHandler):
tokens = [] tokens = []
finish_reason = "length" finish_reason = "length"
stop_sequence_suffix = None stop_sequence_suffix = None
logging.debug(f"Starting completion:")
for (token, _), _ in zip( for (token, _), _ in zip(
generate_step( generate_step(
prompt=prompt, prompt=prompt,
@ -253,6 +257,7 @@ class APIHandler(BaseHTTPRequestHandler):
range(self.max_tokens), range(self.max_tokens),
): ):
detokenizer.add_token(token) detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token) tokens.append(token)
stop_condition = stopping_criteria( stop_condition = stopping_criteria(
tokens, stop_id_sequences, self.tokenizer.eos_token_id tokens, stop_id_sequences, self.tokenizer.eos_token_id
@ -274,6 +279,8 @@ class APIHandler(BaseHTTPRequestHandler):
response = self.generate_response(text, finish_reason, len(prompt), len(tokens)) response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
response_json = json.dumps(response).encode() response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
# Send an additional Content-Length header when it is known # Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json))) self.send_header("Content-Length", str(len(response_json)))
@ -307,6 +314,7 @@ class APIHandler(BaseHTTPRequestHandler):
# to check for stop conditions before writing to the stream. # to check for stop conditions before writing to the stream.
stop_sequence_buffer = [] stop_sequence_buffer = []
stop_sequence_suffix = None stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip( for (token, _), _ in zip(
generate_step( generate_step(
prompt=prompt, prompt=prompt,
@ -319,6 +327,7 @@ class APIHandler(BaseHTTPRequestHandler):
range(self.max_tokens), range(self.max_tokens),
): ):
detokenizer.add_token(token) detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token) tokens.append(token)
stop_sequence_buffer.append(token) stop_sequence_buffer.append(token)
@ -425,7 +434,7 @@ def run(
"mlx_lm.server is not recommended for production as " "mlx_lm.server is not recommended for production as "
"it only implements basic security checks." "it only implements basic security checks."
) )
print(f"Starting httpd at {host} on port {port}...") logging.info(f"Starting httpd at {host} on port {port}...")
httpd.serve_forever() httpd.serve_forever()
@ -459,8 +468,20 @@ def main():
action="store_true", action="store_true",
help="Enable trusting remote code for tokenizer", help="Enable trusting remote code for tokenizer",
) )
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level (default: INFO)",
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format="%(asctime)s - %(levelname)s - %(message)s",
)
# Building tokenizer_config # Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}