Prevent llms/mlx_lm from serving the local directory as a webserver (#498)

* Don't serve local directory

BaseHTTPRequestHandler serves the current directory by default. Definitely not intended behaviour. Remove the "do_HEAD" and "do_GET" methods.

* Fix typo in method name

I assume hanlde_stream was intended to be called handle_stream

* Fix outdated typehint

load_model returns nn.Module, however fetch_from_hub was not updated to reflect the change

* Add some more type hints

* Add warnings for using in prod

Add a warning to README and runtime, discouraging use in production. The warning is the same as on the python docs for HTTPServer https://docs.python.org/3/library/http.server.html

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Y4hL
2024-02-28 05:40:42 +02:00
committed by GitHub
parent 676e574eff
commit ea92f623d6
5 changed files with 32 additions and 9 deletions

View File

@@ -4,6 +4,7 @@ import argparse
import json
import time
import uuid
import warnings
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Callable, List, Optional
@@ -158,6 +159,14 @@ def create_completion_chunk_response(completion_id, requested_model, next_chunk)
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs):
# Prevent exposing local directory by deleting HEAD and GET methods
delattr(self, "do_HEAD")
delattr(self, "do_GET")
super().__init__(*args, **kwargs)
def _set_headers(self, status_code=200):
self.send_response(status_code)
self.send_header("Content-type", "application/json")
@@ -227,7 +236,7 @@ class APIHandler(BaseHTTPRequestHandler):
text = _tokenizer.decode(tokens)
return response_creator(response_id, requested_model, prompt, tokens, text)
def hanlde_stream(
def handle_stream(
self,
prompt: mx.array,
response_id: str,
@@ -306,7 +315,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.wfile.write(f"data: [DONE]\n\n".encode())
self.wfile.flush()
def handle_chat_completions(self, post_data):
def handle_chat_completions(self, post_data: bytes):
body = json.loads(post_data.decode("utf-8"))
chat_id = f"chatcmpl-{uuid.uuid4()}"
if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -352,7 +361,7 @@ class APIHandler(BaseHTTPRequestHandler):
create_chat_response,
)
else:
self.hanlde_stream(
self.handle_stream(
prompt,
chat_id,
requested_model,
@@ -366,7 +375,7 @@ class APIHandler(BaseHTTPRequestHandler):
create_chat_chunk_response,
)
def handle_completions(self, post_data):
def handle_completions(self, post_data: bytes):
body = json.loads(post_data.decode("utf-8"))
completion_id = f"cmpl-{uuid.uuid4()}"
prompt_text = body["prompt"]
@@ -403,7 +412,7 @@ class APIHandler(BaseHTTPRequestHandler):
create_completion_response,
)
else:
self.hanlde_stream(
self.handle_stream(
prompt,
completion_id,
requested_model,
@@ -421,6 +430,10 @@ class APIHandler(BaseHTTPRequestHandler):
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
server_address = (host, port)
httpd = server_class(server_address, handler_class)
warnings.warn(
"mlx_lm.server is not recommended for production as "
"it only implements basic security checks."
)
print(f"Starting httpd at {host} on port {port}...")
httpd.serve_forever()