Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Prevent llms/mlx_lm from serving the local directory as a webserver (#498)
* Don't serve local directory

  BaseHTTPRequestHandler serves the current directory by default. This is definitely not intended behaviour. Remove the `do_HEAD` and `do_GET` methods.

* Fix typo in method name

  I assume `hanlde_stream` was intended to be called `handle_stream`.

* Fix outdated typehint

  `load_model` returns `nn.Module`, however `fetch_from_hub` was not updated to reflect the change.

* Add some more type hints

* Add warnings for using in prod

  Add a warning to the README and at runtime, discouraging use in production. The warning is the same as in the Python docs for HTTPServer: https://docs.python.org/3/library/http.server.html

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in: parent 676e574eff, commit ea92f623d6
@@ -12,3 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
+- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
@@ -4,6 +4,10 @@ You use `mlx-lm` to make an HTTP API for generating text with any supported
 model. The HTTP API is intended to be similar to the [OpenAI chat
 API](https://platform.openai.com/docs/api-reference).
 
+> [!NOTE]
+> The MLX LM server is not recommended for production as it only implements
+> basic security checks.
+
 Start the server with:
 
 ```shell
@@ -61,5 +65,9 @@ curl localhost:8080/v1/chat/completions \
 
 - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
   Defaults to `1.0`.
-- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
-- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.
+
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
+  Defaults to `1.0`.
+
+- `repetition_context_size`: (Optional) The size of the context window for
+  applying repetition penalty. Defaults to `20`.
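For illustration, a request that exercises these parameters might look like the following. This is a hypothetical client sketch, assuming the server is already running on `localhost:8080` and accepts OpenAI-style `messages` as in the README's `curl` example; the prompt text and parameter values are placeholders.

```python
import json
import urllib.request

# Placeholder request body using the parameters documented above.
payload = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "top_p": 1.0,                   # nucleus sampling parameter, default 1.0
    "repetition_penalty": 1.0,      # penalty applied to repeated tokens, default 1.0
    "repetition_context_size": 20,  # context window for the repetition penalty, default 20
}

request = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    print(json.load(response))
```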
@@ -7,6 +7,7 @@ import shutil
 from pathlib import Path
 
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import yaml
 from mlx.utils import tree_flatten, tree_map
@@ -68,7 +69,7 @@ def slerp(t, w1, w2, eps=1e-5):
     return s1 * w1 + s2 * w2
 
 
-def merge_models(base_model, model, config):
+def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
     method = config.get("method", None)
     if method != "slerp":
         raise ValueError(f"Merge method {method} not supported")
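For readers unfamiliar with the merge method referenced here: spherical linear interpolation (slerp) blends two weight tensors along the arc between them rather than along a straight line. The sketch below is a hypothetical NumPy illustration of that computation, ending in the same `s1 * w1 + s2 * w2` form as the function above; it is not the repository's exact implementation.

```python
import numpy as np

def slerp_sketch(t, w1, w2, eps=1e-5):
    # Measure the angle between the two (flattened, normalized) weight tensors.
    v1 = w1.ravel() / (np.linalg.norm(w1) + eps)
    v2 = w2.ravel() / (np.linalg.norm(w2) + eps)
    omega = np.arccos(np.clip(np.dot(v1, v2), -1.0, 1.0))
    if np.sin(omega) < eps:
        # Nearly parallel weights: fall back to plain linear interpolation.
        return (1.0 - t) * w1 + t * w2
    # Interpolation coefficients along the arc between the two tensors.
    s1 = np.sin((1.0 - t) * omega) / np.sin(omega)
    s2 = np.sin(t * omega) / np.sin(omega)
    return s1 * w1 + s2 * w2
```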
@@ -4,6 +4,7 @@ import argparse
 import json
 import time
 import uuid
+import warnings
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Callable, List, Optional
@@ -158,6 +159,14 @@ def create_completion_chunk_response(completion_id, requested_model, next_chunk)
 
 
 class APIHandler(BaseHTTPRequestHandler):
+
+    def __init__(self, *args, **kwargs):
+        # Prevent exposing local directory by deleting HEAD and GET methods
+        delattr(self, "do_HEAD")
+        delattr(self, "do_GET")
+
+        super().__init__(*args, **kwargs)
+
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
         self.send_header("Content-type", "application/json")
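The idea behind the change above: `http.server` handlers only respond to HTTP methods for which a `do_<METHOD>` attribute exists, so once `do_HEAD` and `do_GET` are gone the server answers GET and HEAD with `501 Unsupported method` instead of any file-serving behaviour. The standalone sketch below illustrates the same effect by simply never defining those methods; it is a hypothetical example, not the mlx_lm server code.

```python
from http.server import BaseHTTPRequestHandler, HTTPServer

class PostOnlyHandler(BaseHTTPRequestHandler):
    # No do_GET or do_HEAD defined: GET and HEAD requests receive a 501 response.
    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        # Report how many bytes were posted, just to show the handler ran.
        self.wfile.write(b'{"received": %d}' % len(body))

if __name__ == "__main__":
    HTTPServer(("127.0.0.1", 8080), PostOnlyHandler).serve_forever()
```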
@@ -227,7 +236,7 @@ class APIHandler(BaseHTTPRequestHandler):
         text = _tokenizer.decode(tokens)
         return response_creator(response_id, requested_model, prompt, tokens, text)
 
-    def hanlde_stream(
+    def handle_stream(
         self,
         prompt: mx.array,
         response_id: str,
@@ -306,7 +315,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.wfile.write(f"data: [DONE]\n\n".encode())
         self.wfile.flush()
 
-    def handle_chat_completions(self, post_data):
+    def handle_chat_completions(self, post_data: bytes):
         body = json.loads(post_data.decode("utf-8"))
         chat_id = f"chatcmpl-{uuid.uuid4()}"
         if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
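As context for the `apply_chat_template` check in the hunk above: Hugging Face tokenizers that ship a chat template can turn OpenAI-style messages into a token prompt directly. The following is a minimal, hypothetical sketch of that pattern, not the server's exact code; `tokenizer` here is a placeholder.

```python
messages = [{"role": "user", "content": "Hello!"}]

if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
    # The template inserts role markers and a generation prompt for the assistant.
    prompt_tokens = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True
    )
else:
    # Fallback for tokenizers without a chat template: encode the raw text.
    prompt_tokens = tokenizer.encode(messages[-1]["content"])
```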
@@ -352,7 +361,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_chat_response,
             )
         else:
-            self.hanlde_stream(
+            self.handle_stream(
                 prompt,
                 chat_id,
                 requested_model,
@@ -366,7 +375,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_chat_chunk_response,
             )
 
-    def handle_completions(self, post_data):
+    def handle_completions(self, post_data: bytes):
         body = json.loads(post_data.decode("utf-8"))
         completion_id = f"cmpl-{uuid.uuid4()}"
         prompt_text = body["prompt"]
@@ -403,7 +412,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_completion_response,
             )
         else:
-            self.hanlde_stream(
+            self.handle_stream(
                 prompt,
                 completion_id,
                 requested_model,
@@ -421,6 +430,10 @@
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
     server_address = (host, port)
     httpd = server_class(server_address, handler_class)
+    warnings.warn(
+        "mlx_lm.server is not recommended for production as "
+        "it only implements basic security checks."
+    )
     print(f"Starting httpd at {host} on port {port}...")
     httpd.serve_forever()
@@ -390,7 +390,7 @@ def load(
 
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
-) -> Tuple[Dict, dict, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
 
     config = AutoConfig.from_pretrained(model_path)
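A hypothetical usage sketch reflecting the corrected annotation: `fetch_from_hub` now returns the loaded `nn.Module` (rather than a plain weights dict) along with the config and tokenizer. The model path below is a placeholder, and the import line assumes the utilities module in this diff is importable as `mlx_lm.utils`.

```python
from pathlib import Path

import mlx.nn as nn
from mlx_lm.utils import fetch_from_hub  # assumed import path for this module

model, config, tokenizer = fetch_from_hub(Path("path/to/model"), lazy=False)
assert isinstance(model, nn.Module)  # previously annotated as a plain Dict
```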