mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add /v1/models endpoint to mlx_lm.server (#984)
* Add 'models' endpoint to server * Add test for new 'models' server endpoint * Check hf_cache for mlx models * update tests to check hf_cache for models * simplify test * doc --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
76710f61af
commit
d812516d3d
@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
|
||||
|
||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||
rlative to the directory the server was started in.
|
||||
|
||||
### List Models
|
||||
|
||||
Use the `v1/models` endpoint to list available models:
|
||||
|
||||
```shell
|
||||
curl localhost:8080/v1/models -H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
This will return a list of locally available models where each model in the
|
||||
list contains the following fields:
|
||||
|
||||
- `"id"`: The Hugging Face repo id.
|
||||
- `"created"`: A timestamp representing the model creation time.
|
||||
|
@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
from .utils import generate_step, load
|
||||
|
||||
@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
prompt = self.tokenizer.encode(prompt_text)
|
||||
return mx.array(prompt)
|
||||
|
||||
def do_GET(self):
|
||||
"""
|
||||
Respond to a GET request from a client.
|
||||
"""
|
||||
if self.path == "/v1/models":
|
||||
self.handle_models_request()
|
||||
else:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not Found")
|
||||
|
||||
def handle_models_request(self):
|
||||
"""
|
||||
Handle a GET request for the /v1/models endpoint.
|
||||
"""
|
||||
self._set_completion_headers(200)
|
||||
self.end_headers()
|
||||
|
||||
# Scan the cache directory for downloaded mlx models
|
||||
hf_cache_info = scan_cache_dir()
|
||||
downloaded_models = [
|
||||
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
|
||||
]
|
||||
|
||||
# Create a list of available models
|
||||
models = [
|
||||
{
|
||||
"id": repo.repo_id,
|
||||
"object": "model",
|
||||
"created": self.created,
|
||||
}
|
||||
for repo in downloaded_models
|
||||
]
|
||||
|
||||
response = {"object": "list", "data": models}
|
||||
|
||||
response_json = json.dumps(response).encode()
|
||||
self.wfile.write(response_json)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
def run(
|
||||
host: str,
|
||||
|
@ -1,5 +1,7 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import http
|
||||
import json
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
|
||||
self.assertIn("id", response_body)
|
||||
self.assertIn("choices", response_body)
|
||||
|
||||
def test_handle_models(self):
|
||||
url = f"http://localhost:{self.port}/v1/models"
|
||||
response = requests.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response_body = json.loads(response.text)
|
||||
self.assertEqual(response_body["object"], "list")
|
||||
self.assertIsInstance(response_body["data"], list)
|
||||
self.assertGreater(len(response_body["data"]), 0)
|
||||
model = response_body["data"][0]
|
||||
self.assertIn("id", model)
|
||||
self.assertEqual(model["object"], "model")
|
||||
self.assertIn("created", model)
|
||||
|
||||
def test_sequence_overlap(self):
|
||||
from mlx_lm.server import sequence_overlap
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user