Add /v1/models endpoint to mlx_lm.server (#984)

* Add 'models' endpoint to server

* Add test for new 'models' server endpoint

* Check hf_cache for mlx models

* update tests to check hf_cache for models

* simplify test

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
jamesm131 2024-09-29 00:21:11 +10:00 committed by GitHub
parent 76710f61af
commit d812516d3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 0 deletions

View File

@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in.
### List Models
Use the `v1/models` endpoint to list available models:
```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```
This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.

View File

@ -11,6 +11,7 @@ from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from .utils import generate_step, load
@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
def do_GET(self):
"""
Respond to a GET request from a client.
"""
if self.path == "/v1/models":
self.handle_models_request()
else:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
def handle_models_request(self):
"""
Handle a GET request for the /v1/models endpoint.
"""
self._set_completion_headers(200)
self.end_headers()
# Scan the cache directory for downloaded mlx models
hf_cache_info = scan_cache_dir()
downloaded_models = [
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
]
# Create a list of available models
models = [
{
"id": repo.repo_id,
"object": "model",
"created": self.created,
}
for repo in downloaded_models
]
response = {"object": "list", "data": models}
response_json = json.dumps(response).encode()
self.wfile.write(response_json)
self.wfile.flush()
def run(
host: str,

View File

@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
import http
import json
import threading
import unittest
@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
def test_handle_models(self):
url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url)
self.assertEqual(response.status_code, 200)
response_body = json.loads(response.text)
self.assertEqual(response_body["object"], "list")
self.assertIsInstance(response_body["data"], list)
self.assertGreater(len(response_body["data"]), 0)
model = response_body["data"][0]
self.assertIn("id", model)
self.assertEqual(model["object"], "model")
self.assertIn("created", model)
def test_sequence_overlap(self):
from mlx_lm.server import sequence_overlap