mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
fix tests
This commit is contained in:
parent
d85010bf4b
commit
d6222ae7ff
@ -30,6 +30,13 @@ from .models.cache import make_prompt_cache
|
||||
from .utils import generate_step, load
|
||||
|
||||
|
||||
def get_system_fingerprint():
    """Return a fingerprint string for the current software/hardware stack.

    The fingerprint joins, with ``-`` separators, the package version, the
    MLX version, the ``platform.platform()`` string and the Metal GPU
    architecture name. The architecture part is an empty string when Metal
    is unavailable or does not report an architecture.
    """
    # Only query Metal when the backend is present; asking for device info
    # on a build without Metal would fail instead of yielding "".
    if mx.metal.is_available():
        gpu_arch = mx.metal.device_info().get("architecture", "")
    else:
        gpu_arch = ""
    return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
|
||||
|
||||
|
||||
class StopCondition(NamedTuple):
|
||||
stop_met: bool
|
||||
trim_length: int
|
||||
@ -180,9 +187,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: ModelProvider,
|
||||
prompt_cache: PromptCache,
|
||||
system_fingerprint: str,
|
||||
*args,
|
||||
prompt_cache: Optional[PromptCache] = None,
|
||||
system_fingerprint: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -190,8 +197,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
"""
|
||||
self.created = int(time.time())
|
||||
self.model_provider = model_provider
|
||||
self.prompt_cache = prompt_cache
|
||||
self.system_fingerprint = system_fingerprint
|
||||
self.prompt_cache = prompt_cache or PromptCache()
|
||||
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _set_cors_headers(self):
|
||||
@ -725,14 +732,14 @@ def run(
|
||||
):
|
||||
server_address = (host, port)
|
||||
prompt_cache = PromptCache()
|
||||
system_fingerprint = (
|
||||
f"{__version__}-{mx.__version__}-{platform.platform()}-"
|
||||
f"{mx.metal.device_info().get('architecture', '')}"
|
||||
)
|
||||
httpd = server_class(
|
||||
server_address,
|
||||
lambda *args, **kwargs: handler_class(
|
||||
model_provider, prompt_cache, system_fingerprint, *args, **kwargs
|
||||
model_provider,
|
||||
prompt_cache=prompt_cache,
|
||||
system_fingerprint=get_system_fingerprint(),
|
||||
*args,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
warnings.warn(
|
||||
|
@ -14,6 +14,7 @@ class DummyModelProvider:
|
||||
def __init__(self):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||
self.model_key = (HF_MODEL_PATH, None)
|
||||
|
||||
def load(self, model, adapter=None):
|
||||
assert model in ["default_model", "chat_model"]
|
||||
|
Loading…
Reference in New Issue
Block a user