diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 9c42d410..55be1c9c 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
 
 - `adapters`: (Optional) A string path to low-rank adapters. The path must be
-  rlative to the directory the server was started in.
+  relative to the directory the server was started in.
+
+### List Models
+
+Use the `/v1/models` endpoint to list available models:
+
+```shell
+curl localhost:8080/v1/models -H "Content-Type: application/json"
+```
+
+This will return a list of locally available models. Each model in the list
+contains the following fields:
+
+- `"id"`: The Hugging Face repo id.
+- `"created"`: A timestamp representing the model creation time.
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 79ac1836..f2d8b86a 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
 
 import mlx.core as mx
+from huggingface_hub import scan_cache_dir
 
 from .utils import generate_step, load
 
@@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
         prompt = self.tokenizer.encode(prompt_text)
         return mx.array(prompt)
 
+    def do_GET(self):
+        """
+        Respond to a GET request from a client.
+        """
+        if self.path == "/v1/models":
+            self.handle_models_request()
+        else:
+            self._set_completion_headers(404)
+            self.end_headers()
+            self.wfile.write(b"Not Found")
+
+    def handle_models_request(self):
+        """
+        Handle a GET request for the /v1/models endpoint.
+        """
+        self._set_completion_headers(200)
+        self.end_headers()
+
+        # Scan the cache directory for downloaded mlx models
+        hf_cache_info = scan_cache_dir()
+        downloaded_models = [
+            repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
+        ]
+
+        # Create a list of available models
+        models = [
+            {
+                "id": repo.repo_id,
+                "object": "model",
+                "created": self.created,
+            }
+            for repo in downloaded_models
+        ]
+
+        response = {"object": "list", "data": models}
+
+        response_json = json.dumps(response).encode()
+        self.wfile.write(response_json)
+        self.wfile.flush()
+
 
 def run(
     host: str,
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index baea664a..cbcccfbe 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -1,5 +1,7 @@
 # Copyright © 2024 Apple Inc.
+
 import http
+import json
 import threading
 import unittest
 
@@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_handle_models(self):
+        url = f"http://localhost:{self.port}/v1/models"
+        response = requests.get(url)
+        self.assertEqual(response.status_code, 200)
+        response_body = json.loads(response.text)
+        self.assertEqual(response_body["object"], "list")
+        self.assertIsInstance(response_body["data"], list)
+        self.assertGreater(len(response_body["data"]), 0)
+        model = response_body["data"][0]
+        self.assertIn("id", model)
+        self.assertEqual(model["object"], "model")
+        self.assertIn("created", model)
+
     def test_sequence_overlap(self):
         from mlx_lm.server import sequence_overlap
 
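
For reference, a minimal client for the new endpoint might look like the sketch below. It is not part of the patch: it assumes a server is already running on `localhost:8080` (the address used in the SERVER.md examples) and uses `requests`, which the test suite above already relies on. The expected response shape mirrors the dict built in `handle_models_request`.

```python
# Sketch: query the new /v1/models endpoint from a client.
# Assumes an mlx_lm server is already running on localhost:8080.
import requests

response = requests.get("http://localhost:8080/v1/models")
response.raise_for_status()

body = response.json()
# The handler returns {"object": "list", "data": [...]} where each entry is
# {"id": <Hugging Face repo id>, "object": "model", "created": <timestamp>}.
for model in body["data"]:
    print(model["id"], model["created"])
```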
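
Note that `handle_models_request` treats any cached repo whose id contains the substring `mlx` as an available model (e.g. repos under `mlx-community/`). The following sketch, again not part of the patch, previews which cached repos would be listed by replicating the handler's filter:

```python
# Sketch: preview which cached Hugging Face repos the /v1/models
# endpoint would report, using the same substring filter as the handler.
from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()
for repo in cache_info.repos:
    if "mlx" in repo.repo_id:  # same check as handle_models_request
        print(repo.repo_id)
```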