diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 6bd419a9..3bca9bd3 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -12,3 +12,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
+- Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index e7dd5578..68bb3545 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -4,6 +4,10 @@ You use `mlx-lm` to make an HTTP API for generating text with any supported
 model. The HTTP API is intended to be similar to the [OpenAI chat
 API](https://platform.openai.com/docs/api-reference).
 
+> [!NOTE]
+> The MLX LM server is not recommended for production as it only implements
+> basic security checks.
+
 Start the server with:
 
 ```shell
@@ -61,5 +65,9 @@ curl localhost:8080/v1/chat/completions \
 
 - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
   Defaults to `1.0`.
-- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. Defaults to `1.0`.
-- `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`.
\ No newline at end of file
+
+- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
+  Defaults to `1.0`.
+
+- `repetition_context_size`: (Optional) The size of the context window for
+  applying repetition penalty. Defaults to `20`.
diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py
index d2f54316..46fb87a8 100644
--- a/llms/mlx_lm/merge.py
+++ b/llms/mlx_lm/merge.py
@@ -7,6 +7,7 @@ import shutil
 from pathlib import Path
 
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import yaml
 from mlx.utils import tree_flatten, tree_map
@@ -68,7 +69,7 @@ def slerp(t, w1, w2, eps=1e-5):
     return s1 * w1 + s2 * w2
 
 
-def merge_models(base_model, model, config):
+def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
     method = config.get("method", None)
     if method != "slerp":
         raise ValueError(f"Merge method {method} not supported")
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 894c2e37..3fe04c88 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -4,6 +4,7 @@ import argparse
 import json
 import time
 import uuid
+import warnings
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Callable, List, Optional
@@ -158,6 +159,14 @@ def create_completion_chunk_response(completion_id, requested_model, next_chunk)
 
 
 class APIHandler(BaseHTTPRequestHandler):
+
+    def __init__(self, *args, **kwargs):
+        # Prevent exposing local directory by deleting HEAD and GET methods
+        delattr(self, "do_HEAD")
+        delattr(self, "do_GET")
+
+        super().__init__(*args, **kwargs)
+
     def _set_headers(self, status_code=200):
         self.send_response(status_code)
         self.send_header("Content-type", "application/json")
@@ -227,7 +236,7 @@ class APIHandler(BaseHTTPRequestHandler):
         text = _tokenizer.decode(tokens)
         return response_creator(response_id, requested_model, prompt, tokens, text)
 
-    def hanlde_stream(
+    def handle_stream(
         self,
         prompt: mx.array,
         response_id: str,
@@ -306,7 +315,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.wfile.write(f"data: [DONE]\n\n".encode())
         self.wfile.flush()
 
-    def handle_chat_completions(self, post_data):
+    def handle_chat_completions(self, post_data: bytes):
         body = json.loads(post_data.decode("utf-8"))
         chat_id = f"chatcmpl-{uuid.uuid4()}"
         if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template:
@@ -352,7 +361,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_chat_response,
             )
         else:
-            self.hanlde_stream(
+            self.handle_stream(
                 prompt,
                 chat_id,
                 requested_model,
@@ -366,7 +375,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_chat_chunk_response,
             )
 
-    def handle_completions(self, post_data):
+    def handle_completions(self, post_data: bytes):
         body = json.loads(post_data.decode("utf-8"))
         completion_id = f"cmpl-{uuid.uuid4()}"
         prompt_text = body["prompt"]
@@ -403,7 +412,7 @@ class APIHandler(BaseHTTPRequestHandler):
                 create_completion_response,
             )
         else:
-            self.hanlde_stream(
+            self.handle_stream(
                 prompt,
                 completion_id,
                 requested_model,
@@ -421,6 +430,10 @@ class APIHandler(BaseHTTPRequestHandler):
 
 def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
     server_address = (host, port)
     httpd = server_class(server_address, handler_class)
+    warnings.warn(
+        "mlx_lm.server is not recommended for production as "
+        "it only implements basic security checks."
+    )
     print(f"Starting httpd at {host} on port {port}...")
     httpd.serve_forever()
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ca78088c..c3e0b191 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -390,7 +390,7 @@ def load(
 
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
-) -> Tuple[Dict, dict, PreTrainedTokenizer]:
+) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
 
     config = AutoConfig.from_pretrained(model_path)