Adapters loading (#902)

* Added functionality to load adapters through POST requests so you do not need to restart the server

* ran pre-commit

* nits

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Khush Gupta, 2024-08-01 16:18:18 -07:00, committed by GitHub
parent 85dc76f6e0
commit 8fa12b0058
3 changed files with 24 additions and 7 deletions


@@ -78,3 +78,10 @@ curl localhost:8080/v1/chat/completions \
 - `logprobs`: (Optional) An integer specifying the number of top tokens and
   corresponding log probabilities to return for each output in the generated
   sequence. If set, this can be any value between 1 and 10, inclusive.
+
+- `model`: (Optional) A string path to a local model or Hugging Face repo id.
+  If the path is local, it must be relative to the directory the server was
+  started in.
+
+- `adapters`: (Optional) A string path to low-rank adapters. The path must be
+  relative to the directory the server was started in.
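
For illustration, a request that exercises the new `adapters` field might look like the following sketch (it assumes a server already running on localhost:8080; the adapter path and message are hypothetical):

```python
import requests  # assumes `pip install requests`

# Hypothetical adapter directory, relative to where the server was started.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "adapters": "path/to/adapters",
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 50,
    },
)
print(resp.json())
```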


@@ -97,8 +97,9 @@ class ModelProvider:
                 "Local models must be relative to the current working dir."
             )
 
-    def load(self, model_path):
-        if self.model_key == model_path:
+    # Added in adapter_path to load dynamically
+    def load(self, model_path, adapter_path=None):
+        if self.model_key == (model_path, adapter_path):
             return self.model, self.tokenizer
 
         # Remove the old model if it exists.
@@ -116,18 +117,22 @@ class ModelProvider:
 
         if model_path == "default_model" and self.cli_args.model is not None:
             model, tokenizer = load(
                 self.cli_args.model,
-                adapter_path=self.cli_args.adapter_path,
+                adapter_path=(
+                    adapter_path if adapter_path else self.cli_args.adapter_path
+                ),  # if the user doesn't change the model but adds an adapter path
                 tokenizer_config=tokenizer_config,
             )
         else:
             self._validate_model_path(model_path)
-            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+            model, tokenizer = load(
+                model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
+            )
 
         if self.cli_args.use_default_chat_template:
             if tokenizer.chat_template is None:
                 tokenizer.chat_template = tokenizer.default_chat_template
 
-        self.model_key = model_path
+        self.model_key = (model_path, adapter_path)
         self.model = model
         self.tokenizer = tokenizer
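
The key change above is that the provider's cache key becomes the pair (model path, adapter path), so a request that keeps the same model but names a different adapter forces a reload instead of returning stale cached weights. A minimal standalone sketch of that pattern (illustrative names only, not the server's actual classes):

```python
# Sketch of caching keyed on (model_path, adapter_path).
class CachedLoader:
    def __init__(self):
        self.key = None
        self.cached = None

    def load(self, model_path, adapter_path=None):
        if self.key == (model_path, adapter_path):
            return self.cached  # hit: same model AND same adapter
        # Miss: either the model or the adapter changed, so rebuild.
        self.cached = (model_path, adapter_path)  # stand-in for real weights
        self.key = (model_path, adapter_path)
        return self.cached


loader = CachedLoader()
loader.load("default_model")              # loads fresh
loader.load("default_model")              # cache hit, no reload
loader.load("default_model", "adapters")  # adapter changed -> reload
```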
@@ -193,6 +198,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.stream = self.body.get("stream", False)
         self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
+        self.adapter = self.body.get("adapters", None)
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
         self.top_p = self.body.get("top_p", 1.0)
@@ -204,7 +210,9 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # Load the model if needed
         try:
-            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+            self.model, self.tokenizer = self.model_provider.load(
+                self.requested_model, self.adapter
+            )
         except:
             self._set_completion_headers(404)
             self.end_headers()
@@ -278,6 +286,8 @@ class APIHandler(BaseHTTPRequestHandler):
 
         if not isinstance(self.requested_model, str):
             raise ValueError("model must be a string")
+        if self.adapter is not None and not isinstance(self.adapter, str):
+            raise ValueError("adapter must be a string")
 
     def generate_response(
         self,
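
The new check mirrors the existing `model` validation: `adapters`, when present, must be a string path. A hedged illustration of a payload this would reject (field values are made up; it assumes a running server):

```python
import requests

# Invalid: "adapters" must be a single string path, not a list.
bad = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hi"}],
        "adapters": ["a", "b"],
    },
)
# Server-side, this trips ValueError("adapter must be a string")
# during request validation, before any model load is attempted.
```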


@@ -12,7 +12,7 @@ class DummyModelProvider:
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
         self.model, self.tokenizer = load(HF_MODEL_PATH)
 
-    def load(self, model):
+    def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]
         return self.model, self.tokenizer
 