mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Adapters loading (#902)
* Added the ability to load adapters via POST requests, so the server does not need to be restarted * ran pre-commit * nits * fix test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -97,8 +97,9 @@ class ModelProvider:
|
||||
"Local models must be relative to the current working dir."
|
||||
)
|
||||
|
||||
def load(self, model_path):
|
||||
if self.model_key == model_path:
|
||||
# Added in adapter_path to load dynamically
|
||||
def load(self, model_path, adapter_path=None):
|
||||
if self.model_key == (model_path, adapter_path):
|
||||
return self.model, self.tokenizer
|
||||
|
||||
# Remove the old model if it exists.
|
||||
@@ -116,18 +117,22 @@ class ModelProvider:
|
||||
if model_path == "default_model" and self.cli_args.model is not None:
|
||||
model, tokenizer = load(
|
||||
self.cli_args.model,
|
||||
adapter_path=self.cli_args.adapter_path,
|
||||
adapter_path=(
|
||||
adapter_path if adapter_path else self.cli_args.adapter_path
|
||||
), # if the user doesn't change the model but adds an adapter path
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
else:
|
||||
self._validate_model_path(model_path)
|
||||
model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
|
||||
model, tokenizer = load(
|
||||
model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
|
||||
)
|
||||
|
||||
if self.cli_args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
self.model_key = model_path
|
||||
self.model_key = (model_path, adapter_path)
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@@ -193,6 +198,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.stream = self.body.get("stream", False)
|
||||
self.stream_options = self.body.get("stream_options", None)
|
||||
self.requested_model = self.body.get("model", "default_model")
|
||||
self.adapter = self.body.get("adapters", None)
|
||||
self.max_tokens = self.body.get("max_tokens", 100)
|
||||
self.temperature = self.body.get("temperature", 1.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
@@ -204,7 +210,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
# Load the model if needed
|
||||
try:
|
||||
self.model, self.tokenizer = self.model_provider.load(self.requested_model)
|
||||
self.model, self.tokenizer = self.model_provider.load(
|
||||
self.requested_model, self.adapter
|
||||
)
|
||||
except:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
@@ -278,6 +286,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
if not isinstance(self.requested_model, str):
|
||||
raise ValueError("model must be a string")
|
||||
if self.adapter is not None and not isinstance(self.adapter, str):
|
||||
raise ValueError("adapter must be a string")
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user