Add /v1/models endpoint to mlx_lm.server (#984)

* Add 'models' endpoint to server

* Add test for new 'models' server endpoint

* Check hf_cache for mlx models

* update tests to check hf_cache for models

* simplify test

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
jamesm131 2024-09-29 00:21:11 +10:00 committed by GitHub
parent 76710f61af
commit d812516d3d
3 changed files with 70 additions and 0 deletions


@@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
  relative to the directory the server was started in.
### List Models
Use the `/v1/models` endpoint to list available models:
```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```
This returns a list of the models available in the local Hugging Face cache.
Each model in the list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A Unix timestamp; this server reports the time the request was
  handled rather than the time the model was created.
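For example, with MLX models already downloaded to the local Hugging Face
cache, the response might look like the following (the repo id and timestamp
are illustrative):

```json
{
  "object": "list",
  "data": [
    {
      "id": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
      "object": "model",
      "created": 1727561000
    }
  ]
}
```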


@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from .utils import generate_step, load
@@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
        prompt = self.tokenizer.encode(prompt_text)
        return mx.array(prompt)
    def do_GET(self):
        """
        Respond to a GET request from a client.
        """
        if self.path == "/v1/models":
            self.handle_models_request()
        else:
            self._set_completion_headers(404)
            self.end_headers()
            self.wfile.write(b"Not Found")

    def handle_models_request(self):
        """
        Handle a GET request for the /v1/models endpoint.
        """
        self._set_completion_headers(200)
        self.end_headers()

        # Scan the cache directory for downloaded mlx models
        hf_cache_info = scan_cache_dir()
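        # Heuristic: any cached repo with "mlx" in its id (e.g. repos from
        # the mlx-community organization) is treated as an MLX model.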
downloaded_models = [
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
]
# Create a list of available models
models = [
{
"id": repo.repo_id,
"object": "model",
"created": self.created,
}
for repo in downloaded_models
]
response = {"object": "list", "data": models}
response_json = json.dumps(response).encode()
self.wfile.write(response_json)
self.wfile.flush()
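For context, `scan_cache_dir` is a `huggingface_hub` utility that walks the
local cache and returns an `HFCacheInfo` whose `repos` entries expose fields
such as `repo_id`. A minimal standalone sketch of the same discovery step
(assuming `huggingface_hub` is installed and some MLX models are cached):

```python
# Sketch: replicate the server's model discovery outside the handler.
from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()  # scans the local Hugging Face cache
for repo in cache_info.repos:
    if "mlx" in repo.repo_id:  # same substring heuristic as the server
        print(repo.repo_id, repo.size_on_disk_str)
```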
def run(
    host: str,


@@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
import http
import json
import threading
import unittest
@@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)
    def test_handle_models(self):
        url = f"http://localhost:{self.port}/v1/models"
        response = requests.get(url)
        self.assertEqual(response.status_code, 200)
        response_body = json.loads(response.text)
        self.assertEqual(response_body["object"], "list")
        self.assertIsInstance(response_body["data"], list)
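        # Assumes at least one repo with "mlx" in its id is already present
        # in the local Hugging Face cache; an empty cache fails this check.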
        self.assertGreater(len(response_body["data"]), 0)
        model = response_body["data"][0]
        self.assertIn("id", model)
        self.assertEqual(model["object"], "model")
        self.assertIn("created", model)
    def test_sequence_overlap(self):
        from mlx_lm.server import sequence_overlap