mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fix tests
This commit is contained in:
parent
d85010bf4b
commit
d6222ae7ff
@ -30,6 +30,13 @@ from .models.cache import make_prompt_cache
|
|||||||
from .utils import generate_step, load
|
from .utils import generate_step, load
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_fingerprint():
|
||||||
|
return (
|
||||||
|
f"{__version__}-{mx.__version__}-{platform.platform()}-"
|
||||||
|
f"{mx.metal.device_info().get('architecture', '')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StopCondition(NamedTuple):
|
class StopCondition(NamedTuple):
|
||||||
stop_met: bool
|
stop_met: bool
|
||||||
trim_length: int
|
trim_length: int
|
||||||
@ -180,9 +187,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_provider: ModelProvider,
|
model_provider: ModelProvider,
|
||||||
prompt_cache: PromptCache,
|
|
||||||
system_fingerprint: str,
|
|
||||||
*args,
|
*args,
|
||||||
|
prompt_cache: Optional[PromptCache] = None,
|
||||||
|
system_fingerprint: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -190,8 +197,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
"""
|
"""
|
||||||
self.created = int(time.time())
|
self.created = int(time.time())
|
||||||
self.model_provider = model_provider
|
self.model_provider = model_provider
|
||||||
self.prompt_cache = prompt_cache
|
self.prompt_cache = prompt_cache or PromptCache()
|
||||||
self.system_fingerprint = system_fingerprint
|
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _set_cors_headers(self):
|
def _set_cors_headers(self):
|
||||||
@ -725,14 +732,14 @@ def run(
|
|||||||
):
|
):
|
||||||
server_address = (host, port)
|
server_address = (host, port)
|
||||||
prompt_cache = PromptCache()
|
prompt_cache = PromptCache()
|
||||||
system_fingerprint = (
|
|
||||||
f"{__version__}-{mx.__version__}-{platform.platform()}-"
|
|
||||||
f"{mx.metal.device_info().get('architecture', '')}"
|
|
||||||
)
|
|
||||||
httpd = server_class(
|
httpd = server_class(
|
||||||
server_address,
|
server_address,
|
||||||
lambda *args, **kwargs: handler_class(
|
lambda *args, **kwargs: handler_class(
|
||||||
model_provider, prompt_cache, system_fingerprint, *args, **kwargs
|
model_provider,
|
||||||
|
prompt_cache=prompt_cache,
|
||||||
|
system_fingerprint=get_system_fingerprint(),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -14,6 +14,7 @@ class DummyModelProvider:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||||
|
self.model_key = (HF_MODEL_PATH, None)
|
||||||
|
|
||||||
def load(self, model, adapter=None):
|
def load(self, model, adapter=None):
|
||||||
assert model in ["default_model", "chat_model"]
|
assert model in ["default_model", "chat_model"]
|
||||||
|
Loading…
Reference in New Issue
Block a user