Prevent llms/mlx_lm from serving the local directory as a webserver (#498)

* Don't serve local directory

`BaseHTTPRequestHandler` can end up serving the current directory by default, which is definitely not the intended behaviour. Remove the `do_HEAD` and `do_GET` methods so plain GET/HEAD requests are rejected. (A smoke-test sketch for this fix follows the list below.)

* Fix typo in method name

I assume `hanlde_stream` was intended to be called `handle_stream`.

* Fix outdated typehint

`load_model` returns `nn.Module`, but the return type hint of `fetch_from_hub` was not updated to reflect the change.

* Add some more type hints

* Add warnings for using in prod

Add a warning to the README and at runtime discouraging use in production. The wording matches the warning in the Python docs for `http.server`: https://docs.python.org/3/library/http.server.html

* format

* nits
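
As a quick check of the directory-serving fix above, here is a hypothetical smoke test (it assumes `mlx_lm.server` is already running locally on port 8080, as described in the server README):

```python
# Hypothetical smoke test for the fix above. Before the change a GET for "/"
# could expose local files; afterwards it should be rejected (e.g. 404/501)
# rather than answered with directory contents.
import http.client

conn = http.client.HTTPConnection("localhost", 8080)
conn.request("GET", "/")
response = conn.getresponse()
print(response.status, response.reason)  # expect an error status, not 200
conn.close()
```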

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Y4hL 2024-02-28 05:40:42 +02:00 committed by GitHub
parent 676e574eff
commit ea92f623d6
5 changed files with 32 additions and 9 deletions

ACKNOWLEDGMENTS.md

@@ -12,3 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
+- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example

llms/mlx_lm/SERVER.md

@@ -4,6 +4,10 @@ You use `mlx-lm` to make an HTTP API for generating text with any supported
 model. The HTTP API is intended to be similar to the [OpenAI chat
 API](https://platform.openai.com/docs/api-reference).

+> [!NOTE]
+> The MLX LM server is not recommended for production as it only implements
+> basic security checks.
+
 Start the server with:

 ```shell
@@ -61,5 +65,9 @@ curl localhost:8080/v1/chat/completions \
 - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
   Defaults to `1.0`.
-- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
-- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
+  Defaults to `1.0`.
+- `repetition_context_size`: (Optional) The size of the context window for
+  applying repetition penalty. Defaults to `20`.
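
The two parameters documented above can be exercised with a short client. A minimal sketch (assumes the server is running locally on port 8080 as in this README; the message content and parameter values are illustrative):

```python
# Minimal client sketch for the repetition parameters documented above
# (assumes mlx_lm.server is already running on localhost:8080).
import json
from urllib import request

payload = {
    "messages": [{"role": "user", "content": "Write a short poem."}],
    "repetition_penalty": 1.2,      # > 1.0 discourages repeated tokens
    "repetition_context_size": 20,  # how many recent tokens the penalty sees
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as response:
    print(json.load(response))
```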

llms/mlx_lm/merge.py

@@ -7,6 +7,7 @@ import shutil
 from pathlib import Path

 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import yaml
 from mlx.utils import tree_flatten, tree_map
@@ -68,7 +69,7 @@ def slerp(t, w1, w2, eps=1e-5):
     return s1 * w1 + s2 * w2

-def merge_models(base_model, model, config):
+def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
     method = config.get("method", None)
     if method != "slerp":
         raise ValueError(f"Merge method {method} not supported")

llms/mlx_lm/server.py

@@ -4,6 +4,7 @@ import argparse
 import json
 import time
 import uuid
+import warnings
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Callable, List, Optional
@@ -158,6 +159,14 @@ def create_completion_chunk_response(completion_id, requested_model, next_chunk)
 class APIHandler(BaseHTTPRequestHandler):
+    def __init__(self, *args, **kwargs):
+        # Prevent exposing the local directory by deleting the HEAD and GET handlers
+        delattr(self, "do_HEAD")
+        delattr(self, "do_GET")
+        super().__init__(*args, **kwargs)
+
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
         self.send_header("Content-type", "application/json")
@@ -227,7 +236,7 @@ class APIHandler(BaseHTTPRequestHandler):
         text = _tokenizer.decode(tokens)
         return response_creator(response_id, requested_model, prompt, tokens, text)

-    def hanlde_stream(
+    def handle_stream(
         self,
         prompt: mx.array,
         response_id: str,
@@ -306,7 +315,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.wfile.write(f"data: [DONE]\n\n".encode())
         self.wfile.flush()

-    def handle_chat_completions(self, post_data):
+    def handle_chat_completions(self, post_data: bytes):
         body = json.loads(post_data.decode("utf-8"))
         chat_id = f"chatcmpl-{uuid.uuid4()}"
         if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -352,7 +361,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_chat_response,
             )
         else:
-            self.hanlde_stream(
+            self.handle_stream(
                 prompt,
                 chat_id,
                 requested_model,
@@ -366,7 +375,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_chat_chunk_response,
             )

-    def handle_completions(self, post_data):
+    def handle_completions(self, post_data: bytes):
         body = json.loads(post_data.decode("utf-8"))
         completion_id = f"cmpl-{uuid.uuid4()}"
         prompt_text = body["prompt"]
@@ -403,7 +412,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_completion_response,
             )
         else:
-            self.hanlde_stream(
+            self.handle_stream(
                 prompt,
                 completion_id,
                 requested_model,
@@ -421,6 +430,10 @@ class APIHandler(BaseHTTPRequestHandler):
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
     server_address = (host, port)
     httpd = server_class(server_address, handler_class)
+    warnings.warn(
+        "mlx_lm.server is not recommended for production as "
+        "it only implements basic security checks."
+    )
     print(f"Starting httpd at {host} on port {port}...")
     httpd.serve_forever()
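
One note on the runtime warning added to `run()`: `warnings.warn` with no explicit category raises a `UserWarning` printed to stderr. A user who accepts the risk in local development could silence it; a sketch (the filter's `message` argument is a regex matched against the start of the warning text):

```python
# Sketch: silencing the startup warning during local development only.
import warnings

warnings.filterwarnings(
    "ignore",
    message=r"mlx_lm\.server is not recommended for production",
    category=UserWarning,
)
```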

llms/mlx_lm/utils.py

@@ -390,7 +390,7 @@ def load(

 def fetch_from_hub(
     model_path: Path, lazy: bool = False
-) -> Tuple[Dict, dict, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
     config = AutoConfig.from_pretrained(model_path)
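
With the corrected hint, callers can rely on the first element of the returned tuple being a model. A usage sketch (assumes `fetch_from_hub` is importable from `mlx_lm.utils` as in this diff, and that a converted model already exists at the local path shown, which is illustrative):

```python
# Usage sketch for the corrected signature; the path is illustrative.
from pathlib import Path

import mlx.nn as nn
from mlx_lm.utils import fetch_from_hub

model, config, tokenizer = fetch_from_hub(Path("models/my-model"), lazy=False)
assert isinstance(model, nn.Module)  # now reflected by the return type hint
print(type(tokenizer).__name__)
```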