fix tests

This commit is contained in:
Awni Hannun 2024-10-09 13:17:56 -07:00
parent d85010bf4b
commit d6222ae7ff
2 changed files with 17 additions and 9 deletions

View File

@ -30,6 +30,13 @@ from .models.cache import make_prompt_cache
from .utils import generate_step, load from .utils import generate_step, load
def get_system_fingerprint():
    """Return a hyphen-separated string describing the running stack.

    The parts, in order, are the module-level ``__version__``,
    ``mx.__version__``, the ``platform.platform()`` string, and the
    ``"architecture"`` entry of ``mx.metal.device_info()``. When that entry
    is absent an empty string is used, so the result ends with a hyphen.
    """
    # Software part first, so the values are read in the original order.
    software = f"{__version__}-{mx.__version__}-{platform.platform()}"
    architecture = mx.metal.device_info().get("architecture", "")
    return f"{software}-{architecture}"
class StopCondition(NamedTuple): class StopCondition(NamedTuple):
stop_met: bool stop_met: bool
trim_length: int trim_length: int
@ -180,9 +187,9 @@ class APIHandler(BaseHTTPRequestHandler):
def __init__( def __init__(
self, self,
model_provider: ModelProvider, model_provider: ModelProvider,
prompt_cache: PromptCache,
system_fingerprint: str,
*args, *args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs, **kwargs,
): ):
""" """
@ -190,8 +197,8 @@ class APIHandler(BaseHTTPRequestHandler):
""" """
self.created = int(time.time()) self.created = int(time.time())
self.model_provider = model_provider self.model_provider = model_provider
self.prompt_cache = prompt_cache self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _set_cors_headers(self): def _set_cors_headers(self):
@ -725,14 +732,14 @@ def run(
): ):
server_address = (host, port) server_address = (host, port)
prompt_cache = PromptCache() prompt_cache = PromptCache()
system_fingerprint = (
f"{__version__}-{mx.__version__}-{platform.platform()}-"
f"{mx.metal.device_info().get('architecture', '')}"
)
httpd = server_class( httpd = server_class(
server_address, server_address,
lambda *args, **kwargs: handler_class( lambda *args, **kwargs: handler_class(
model_provider, prompt_cache, system_fingerprint, *args, **kwargs model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
), ),
) )
warnings.warn( warnings.warn(

View File

@ -14,6 +14,7 @@ class DummyModelProvider:
def __init__(self): def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH) self.model, self.tokenizer = load(HF_MODEL_PATH)
self.model_key = (HF_MODEL_PATH, None)
def load(self, model, adapter=None): def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"] assert model in ["default_model", "chat_model"]