From 8d5cf5b0c8279500fc4199f44825f57b5344b4d5 Mon Sep 17 00:00:00 2001 From: Aaron Ng Date: Mon, 22 Apr 2024 07:50:06 -0700 Subject: [PATCH] use logging in mlx server (#705) --- llms/mlx_lm/server.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 9de306da..5464dd1a 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -2,6 +2,7 @@ import argparse import json +import logging import time import uuid import warnings @@ -116,6 +117,8 @@ class APIHandler(BaseHTTPRequestHandler): content_length = int(self.headers["Content-Length"]) raw_body = self.rfile.read(content_length) self.body = json.loads(raw_body.decode()) + indent = "\t" # Backslashes can't be inside of f-strings + logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}") assert isinstance( self.body, dict ), f"Request should be dict, but got {type(self.body)}" @@ -240,6 +243,7 @@ class APIHandler(BaseHTTPRequestHandler): tokens = [] finish_reason = "length" stop_sequence_suffix = None + logging.debug(f"Starting completion:") for (token, _), _ in zip( generate_step( prompt=prompt, @@ -253,6 +257,7 @@ class APIHandler(BaseHTTPRequestHandler): range(self.max_tokens), ): detokenizer.add_token(token) + logging.debug(detokenizer.text) tokens.append(token) stop_condition = stopping_criteria( tokens, stop_id_sequences, self.tokenizer.eos_token_id @@ -274,6 +279,8 @@ class APIHandler(BaseHTTPRequestHandler): response = self.generate_response(text, finish_reason, len(prompt), len(tokens)) response_json = json.dumps(response).encode() + indent = "\t" # Backslashes can't be inside of f-strings + logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") # Send an additional Content-Length header when it is known self.send_header("Content-Length", str(len(response_json))) @@ -307,6 +314,7 @@ class APIHandler(BaseHTTPRequestHandler): # to check for stop conditions before writing to the stream. stop_sequence_buffer = [] stop_sequence_suffix = None + logging.debug(f"Starting stream:") for (token, _), _ in zip( generate_step( prompt=prompt, @@ -319,6 +327,7 @@ class APIHandler(BaseHTTPRequestHandler): range(self.max_tokens), ): detokenizer.add_token(token) + logging.debug(detokenizer.text) tokens.append(token) stop_sequence_buffer.append(token) @@ -425,7 +434,7 @@ def run( "mlx_lm.server is not recommended for production as " "it only implements basic security checks." ) - print(f"Starting httpd at {host} on port {port}...") + logging.info(f"Starting httpd at {host} on port {port}...") httpd.serve_forever() @@ -459,8 +468,20 @@ def main(): action="store_true", help="Enable trusting remote code for tokenizer", ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level (default: INFO)", + ) args = parser.parse_args() + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), None), + format="%(asctime)s - %(levelname)s - %(message)s", + ) + # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}