mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add /v1/models endpoint to mlx_lm.server (#984)
* Add 'models' endpoint to server * Add test for new 'models' server endpoint * Check hf_cache for mlx models * update tests to check hf_cache for models * simplify test * doc --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
76710f61af
commit
d812516d3d
@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
|
|||||||
|
|
||||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||||
relative to the directory the server was started in.
|
relative to the directory the server was started in.
|
||||||
|
|
||||||
|
### List Models
|
||||||
|
|
||||||
|
Use the `v1/models` endpoint to list available models:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl localhost:8080/v1/models -H "Content-Type: application/json"
|
||||||
|
```
|
||||||
|
|
||||||
|
This will return a list of locally available models where each model in the
|
||||||
|
list contains the following fields:
|
||||||
|
|
||||||
|
- `"id"`: The Hugging Face repo id.
|
||||||
|
- `"created"`: A timestamp representing the model creation time.
|
||||||
|
@ -11,6 +11,7 @@ from pathlib import Path
|
|||||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from huggingface_hub import scan_cache_dir
|
||||||
|
|
||||||
from .utils import generate_step, load
|
from .utils import generate_step, load
|
||||||
|
|
||||||
@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
prompt = self.tokenizer.encode(prompt_text)
|
prompt = self.tokenizer.encode(prompt_text)
|
||||||
return mx.array(prompt)
|
return mx.array(prompt)
|
||||||
|
|
||||||
|
def do_GET(self):
|
||||||
|
"""
|
||||||
|
Respond to a GET request from a client.
|
||||||
|
"""
|
||||||
|
if self.path == "/v1/models":
|
||||||
|
self.handle_models_request()
|
||||||
|
else:
|
||||||
|
self._set_completion_headers(404)
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(b"Not Found")
|
||||||
|
|
||||||
|
def handle_models_request(self):
|
||||||
|
"""
|
||||||
|
Handle a GET request for the /v1/models endpoint.
|
||||||
|
"""
|
||||||
|
self._set_completion_headers(200)
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
# Scan the cache directory for downloaded mlx models
|
||||||
|
hf_cache_info = scan_cache_dir()
|
||||||
|
downloaded_models = [
|
||||||
|
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a list of available models
|
||||||
|
models = [
|
||||||
|
{
|
||||||
|
"id": repo.repo_id,
|
||||||
|
"object": "model",
|
||||||
|
"created": self.created,
|
||||||
|
}
|
||||||
|
for repo in downloaded_models
|
||||||
|
]
|
||||||
|
|
||||||
|
response = {"object": "list", "data": models}
|
||||||
|
|
||||||
|
response_json = json.dumps(response).encode()
|
||||||
|
self.wfile.write(response_json)
|
||||||
|
self.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
host: str,
|
host: str,
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import http
|
import http
|
||||||
|
import json
|
||||||
import threading
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
|
|||||||
self.assertIn("id", response_body)
|
self.assertIn("id", response_body)
|
||||||
self.assertIn("choices", response_body)
|
self.assertIn("choices", response_body)
|
||||||
|
|
||||||
|
def test_handle_models(self):
|
||||||
|
url = f"http://localhost:{self.port}/v1/models"
|
||||||
|
response = requests.get(url)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
response_body = json.loads(response.text)
|
||||||
|
self.assertEqual(response_body["object"], "list")
|
||||||
|
self.assertIsInstance(response_body["data"], list)
|
||||||
|
self.assertGreater(len(response_body["data"]), 0)
|
||||||
|
model = response_body["data"][0]
|
||||||
|
self.assertIn("id", model)
|
||||||
|
self.assertEqual(model["object"], "model")
|
||||||
|
self.assertIn("created", model)
|
||||||
|
|
||||||
def test_sequence_overlap(self):
|
def test_sequence_overlap(self):
|
||||||
from mlx_lm.server import sequence_overlap
|
from mlx_lm.server import sequence_overlap
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user