Adapters loading (#902)

* Added functionality to load adapters through POST requests so you do not need to restart the server

* ran pre-commit

* nits

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Khush Gupta 2024-08-01 16:18:18 -07:00 committed by GitHub
parent 85dc76f6e0
commit 8fa12b0058
3 changed files with 24 additions and 7 deletions

@@ -78,3 +78,10 @@ curl localhost:8080/v1/chat/completions \
 - `logprobs`: (Optional) An integer specifying the number of top tokens and
   corresponding log probabilities to return for each output in the generated
   sequence. If set, this can be any value between 1 and 10, inclusive.
+
+- `model`: (Optional) A string path to a local model or Hugging Face repo id.
+  If the path is local, it must be relative to the directory the server was
+  started in.
+
+- `adapters`: (Optional) A string path to low-rank adapters. The path must be
+  relative to the directory the server was started in.
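To make the new field concrete, here is a minimal client-side sketch. It is not part of the commit: it assumes a server already running on localhost:8080 and a hypothetical adapter directory named "adapters" relative to the directory the server was started in.

# Hypothetical usage sketch, not part of this commit: ask the running
# server to apply low-rank adapters without a restart. Assumes the server
# listens on localhost:8080 and "adapters" is a local adapter directory.
import json
import urllib.request

body = {
    "model": "default_model",
    "adapters": "adapters",  # hypothetical adapter path
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 100,
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(body).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))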

@@ -97,8 +97,9 @@ class ModelProvider:
                 "Local models must be relative to the current working dir."
             )

-    def load(self, model_path):
-        if self.model_key == model_path:
+    # Added in adapter_path to load dynamically
+    def load(self, model_path, adapter_path=None):
+        if self.model_key == (model_path, adapter_path):
             return self.model, self.tokenizer

         # Remove the old model if it exists.
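The effect of the tuple key can be seen in isolation. A minimal sketch with hypothetical names, not the server's code: a cached model is reused only when both the model path and the adapter path match the previous request, so changing only the adapter still forces a reload.

# Minimal sketch of the (model_path, adapter_path) cache key; the names
# here are hypothetical and only illustrate the reload rule above.
class CachedLoader:
    def __init__(self, loader):
        self.loader = loader  # callable taking (model_path, adapter_path)
        self.key = None
        self.cached = None

    def load(self, model_path, adapter_path=None):
        if self.key == (model_path, adapter_path):
            return self.cached  # hit only when model AND adapter both match
        self.cached = self.loader(model_path, adapter_path)
        self.key = (model_path, adapter_path)
        return self.cached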
@@ -116,18 +117,22 @@ class ModelProvider:
         if model_path == "default_model" and self.cli_args.model is not None:
             model, tokenizer = load(
                 self.cli_args.model,
-                adapter_path=self.cli_args.adapter_path,
+                adapter_path=(
+                    adapter_path if adapter_path else self.cli_args.adapter_path
+                ),  # if the user doesn't change the model but adds an adapter path
                 tokenizer_config=tokenizer_config,
             )
         else:
             self._validate_model_path(model_path)
-            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+            model, tokenizer = load(
+                model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
+            )

         if self.cli_args.use_default_chat_template:
             if tokenizer.chat_template is None:
                 tokenizer.chat_template = tokenizer.default_chat_template

-        self.model_key = model_path
+        self.model_key = (model_path, adapter_path)
         self.model = model
         self.tokenizer = tokenizer
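The fallback in the default-model branch reduces to a simple precedence rule: an adapter path supplied in the request wins over the --adapter-path the server was started with. A short sketch with hypothetical names:

# Hypothetical sketch of the fallback above: the per-request adapter
# path, when present, overrides the server's CLI default.
def effective_adapter(request_adapter, cli_adapter):
    return request_adapter if request_adapter else cli_adapter

assert effective_adapter("my_adapters", "cli_adapters") == "my_adapters"
assert effective_adapter(None, "cli_adapters") == "cli_adapters"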
@@ -193,6 +198,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.stream = self.body.get("stream", False)
         self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
+        self.adapter = self.body.get("adapters", None)
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
         self.top_p = self.body.get("top_p", 1.0)
@@ -204,7 +210,9 @@ class APIHandler(BaseHTTPRequestHandler):

         # Load the model if needed
         try:
-            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+            self.model, self.tokenizer = self.model_provider.load(
+                self.requested_model, self.adapter
+            )
         except:
             self._set_completion_headers(404)
             self.end_headers()
@@ -278,6 +286,8 @@ class APIHandler(BaseHTTPRequestHandler):
         if not isinstance(self.requested_model, str):
             raise ValueError("model must be a string")
+        if self.adapter is not None and not isinstance(self.adapter, str):
+            raise ValueError("adapter must be a string")

     def generate_response(
         self,
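For illustration, the new check rejects a malformed request before any load is attempted. A standalone sketch of the same validation, using a hypothetical request body rather than the server's code:

# Standalone sketch of the validation above with a hypothetical body:
# a non-string "adapters" value is rejected up front.
body = {"model": "default_model", "adapters": ["a", "b"]}  # wrong type
adapter = body.get("adapters", None)
try:
    if adapter is not None and not isinstance(adapter, str):
        raise ValueError("adapter must be a string")
except ValueError as e:
    print(e)  # -> adapter must be a string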

@@ -12,7 +12,7 @@ class DummyModelProvider:
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
         self.model, self.tokenizer = load(HF_MODEL_PATH)

-    def load(self, model):
+    def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]
         return self.model, self.tokenizer