mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Add /v1/models endpoint to mlx_lm.server (#984)
* Add 'models' endpoint to server * Add test for new 'models' server endpoint * Check hf_cache for mlx models * update tests to check hf_cache for models * simplify test * doc --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
from .utils import generate_step, load
|
||||
|
||||
@@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
prompt = self.tokenizer.encode(prompt_text)
|
||||
return mx.array(prompt)
|
||||
|
||||
def do_GET(self):
|
||||
"""
|
||||
Respond to a GET request from a client.
|
||||
"""
|
||||
if self.path == "/v1/models":
|
||||
self.handle_models_request()
|
||||
else:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not Found")
|
||||
|
||||
def handle_models_request(self):
|
||||
"""
|
||||
Handle a GET request for the /v1/models endpoint.
|
||||
"""
|
||||
self._set_completion_headers(200)
|
||||
self.end_headers()
|
||||
|
||||
# Scan the cache directory for downloaded mlx models
|
||||
hf_cache_info = scan_cache_dir()
|
||||
downloaded_models = [
|
||||
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
|
||||
]
|
||||
|
||||
# Create a list of available models
|
||||
models = [
|
||||
{
|
||||
"id": repo.repo_id,
|
||||
"object": "model",
|
||||
"created": self.created,
|
||||
}
|
||||
for repo in downloaded_models
|
||||
]
|
||||
|
||||
response = {"object": "list", "data": models}
|
||||
|
||||
response_json = json.dumps(response).encode()
|
||||
self.wfile.write(response_json)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
def run(
|
||||
host: str,
|
||||
|
Reference in New Issue
Block a user