From 8fa12b0058a2647dff5d776f84deef43e1cb7720 Mon Sep 17 00:00:00 2001
From: Khush Gupta <78624519+khushgx@users.noreply.github.com>
Date: Thu, 1 Aug 2024 16:18:18 -0700
Subject: [PATCH] Adapters loading (#902)

* Added functionality to load adapters through POST requests so you do not
  need to restart the server

* ran pre-commit

* nits

* fix test

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/SERVER.md     |  7 +++++++
 llms/mlx_lm/server.py     | 22 ++++++++++++++++------
 llms/tests/test_server.py |  2 +-
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 48364bee..9c42d410 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -78,3 +78,10 @@ curl localhost:8080/v1/chat/completions \
 - `logprobs`: (Optional) An integer specifying the number of top tokens and
   corresponding log probabilities to return for each output in the generated
   sequence. If set, this can be any value between 1 and 10, inclusive.
+
+- `model`: (Optional) A string path to a local model or Hugging Face repo id.
+  If the path is local, it must be relative to the directory the server was
+  started in.
+
+- `adapters`: (Optional) A string path to low-rank adapters. The path must be
+  relative to the directory the server was started in.
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index c13878f3..7456399c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -97,8 +97,9 @@ class ModelProvider:
                 "Local models must be relative to the current working dir."
             )
 
-    def load(self, model_path):
-        if self.model_key == model_path:
+    # Accept an optional adapter_path so adapters can be loaded dynamically
+    def load(self, model_path, adapter_path=None):
+        if self.model_key == (model_path, adapter_path):
             return self.model, self.tokenizer
 
         # Remove the old model if it exists.
@@ -116,18 +117,22 @@ class ModelProvider:
         if model_path == "default_model" and self.cli_args.model is not None:
             model, tokenizer = load(
                 self.cli_args.model,
-                adapter_path=self.cli_args.adapter_path,
+                adapter_path=(
+                    adapter_path if adapter_path else self.cli_args.adapter_path
+                ),  # if the user doesn't change the model but adds an adapter path
                 tokenizer_config=tokenizer_config,
             )
         else:
             self._validate_model_path(model_path)
-            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+            model, tokenizer = load(
+                model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
+            )
         if self.cli_args.use_default_chat_template:
             if tokenizer.chat_template is None:
                 tokenizer.chat_template = tokenizer.default_chat_template
 
-        self.model_key = model_path
+        self.model_key = (model_path, adapter_path)
         self.model = model
         self.tokenizer = tokenizer
 
         return self.model, self.tokenizer
@@ -193,6 +198,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.stream = self.body.get("stream", False)
         self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
+        self.adapter = self.body.get("adapters", None)
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
         self.top_p = self.body.get("top_p", 1.0)
@@ -204,7 +210,9 @@
 
         # Load the model if needed
         try:
-            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+            self.model, self.tokenizer = self.model_provider.load(
+                self.requested_model, self.adapter
+            )
         except:
             self._set_completion_headers(404)
             self.end_headers()
@@ -278,6 +286,8 @@ class APIHandler(BaseHTTPRequestHandler):
 
         if not isinstance(self.requested_model, str):
             raise ValueError("model must be a string")
+        if self.adapter is not None and not isinstance(self.adapter, str):
+            raise ValueError("adapter must be a string")
 
     def generate_response(
         self,
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index 4d71a5a3..b8047eaa 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -12,7 +12,7 @@ class DummyModelProvider:
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
         self.model, self.tokenizer = load(HF_MODEL_PATH)
 
-    def load(self, model):
+    def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]
         return self.model, self.tokenizer
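After this patch, a client can switch adapters per request instead of restarting the server. Below is a minimal sketch of such a request using only the Python standard library; it assumes the server is listening on `localhost:8080` (the address used in the SERVER.md examples) and that an `adapters/` directory holding low-rank adapter weights exists relative to the server's working directory. Both values are illustrative, not part of the patch.

```python
# Minimal sketch: exercise the new `adapters` request field.
# Assumes `mlx_lm.server` is listening on localhost:8080 and that
# ./adapters (relative to the server's cwd) contains adapter weights;
# both the address and the path are illustrative.
import json
import urllib.request

payload = {
    "model": "default_model",  # keep the model the server was started with
    "adapters": "adapters",    # hypothetical path, relative to the server's cwd
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 100,
}

req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as response:
    print(json.load(response)["choices"][0]["message"]["content"])
```

Because `ModelProvider.load` now keys its cache on the `(model_path, adapter_path)` tuple, repeating this request reuses the already-loaded weights, while changing the `adapters` value triggers a fresh load.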