From d6222ae7ff2ddf662a5e79ca63bd8bcaacb36ef9 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 13:17:56 -0700
Subject: [PATCH] fix tests

---
Notes (commentary only; not part of the commit message):
  * server.py: the fingerprint f-string moves out of run() into a new
    module-level get_system_fingerprint() helper.
  * server.py: APIHandler.__init__ now takes prompt_cache and
    system_fingerprint after *args (keyword-only), both defaulting to None
    and falling back to PromptCache() / get_system_fingerprint().
  * test_server.py: DummyModelProvider gains a model_key attribute set to
    (HF_MODEL_PATH, None).
  * NOTE(review): in run(), get_system_fingerprint() is now written inside
    the lambda body, so it is evaluated on every lambda call rather than
    once as before - confirm that is intended.

 llms/mlx_lm/server.py     | 25 ++++++++++++++++---------
 llms/tests/test_server.py |  1 +
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index f87c1d00..47a72599 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -30,6 +30,13 @@ from .models.cache import make_prompt_cache
 from .utils import generate_step, load
 
 
+def get_system_fingerprint():
+    return (
+        f"{__version__}-{mx.__version__}-{platform.platform()}-"
+        f"{mx.metal.device_info().get('architecture', '')}"
+    )
+
+
 class StopCondition(NamedTuple):
     stop_met: bool
     trim_length: int
@@ -180,9 +187,9 @@ class APIHandler(BaseHTTPRequestHandler):
     def __init__(
         self,
         model_provider: ModelProvider,
-        prompt_cache: PromptCache,
-        system_fingerprint: str,
         *args,
+        prompt_cache: Optional[PromptCache] = None,
+        system_fingerprint: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -190,8 +197,8 @@ class APIHandler(BaseHTTPRequestHandler):
         """
         self.created = int(time.time())
         self.model_provider = model_provider
-        self.prompt_cache = prompt_cache
-        self.system_fingerprint = system_fingerprint
+        self.prompt_cache = prompt_cache or PromptCache()
+        self.system_fingerprint = system_fingerprint or get_system_fingerprint()
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
@@ -725,14 +732,14 @@ def run(
 ):
     server_address = (host, port)
     prompt_cache = PromptCache()
-    system_fingerprint = (
-        f"{__version__}-{mx.__version__}-{platform.platform()}-"
-        f"{mx.metal.device_info().get('architecture', '')}"
-    )
     httpd = server_class(
         server_address,
         lambda *args, **kwargs: handler_class(
-            model_provider, prompt_cache, system_fingerprint, *args, **kwargs
+            model_provider,
+            prompt_cache=prompt_cache,
+            system_fingerprint=get_system_fingerprint(),
+            *args,
+            **kwargs,
         ),
     )
     warnings.warn(
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index cbcccfbe..ad17554d 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -14,6 +14,7 @@ class DummyModelProvider:
     def __init__(self):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
         self.model, self.tokenizer = load(HF_MODEL_PATH)
+        self.model_key = (HF_MODEL_PATH, None)
 
     def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]