mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Adapters loading (#902)
* Added functionality to load in adapters through post-requests so you do not need to restart the server * ran pre-commit * nits * fix test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
85dc76f6e0
commit
8fa12b0058
@ -78,3 +78,10 @@ curl localhost:8080/v1/chat/completions \
|
||||
- `logprobs`: (Optional) An integer specifying the number of top tokens and
|
||||
corresponding log probabilities to return for each output in the generated
|
||||
sequence. If set, this can be any value between 1 and 10, inclusive.
|
||||
|
||||
- `model`: (Optional) A string path to a local model or Hugging Face repo id.
|
||||
If the path is local is must be relative to the directory the server was
|
||||
started in.
|
||||
|
||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||
rlative to the directory the server was started in.
|
||||
|
@ -97,8 +97,9 @@ class ModelProvider:
|
||||
"Local models must be relative to the current working dir."
|
||||
)
|
||||
|
||||
def load(self, model_path):
|
||||
if self.model_key == model_path:
|
||||
# Added in adapter_path to load dynamically
|
||||
def load(self, model_path, adapter_path=None):
|
||||
if self.model_key == (model_path, adapter_path):
|
||||
return self.model, self.tokenizer
|
||||
|
||||
# Remove the old model if it exists.
|
||||
@ -116,18 +117,22 @@ class ModelProvider:
|
||||
if model_path == "default_model" and self.cli_args.model is not None:
|
||||
model, tokenizer = load(
|
||||
self.cli_args.model,
|
||||
adapter_path=self.cli_args.adapter_path,
|
||||
adapter_path=(
|
||||
adapter_path if adapter_path else self.cli_args.adapter_path
|
||||
), # if the user doesn't change the model but adds an adapter path
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
else:
|
||||
self._validate_model_path(model_path)
|
||||
model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
|
||||
model, tokenizer = load(
|
||||
model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
|
||||
)
|
||||
|
||||
if self.cli_args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
self.model_key = model_path
|
||||
self.model_key = (model_path, adapter_path)
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@ -193,6 +198,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.stream = self.body.get("stream", False)
|
||||
self.stream_options = self.body.get("stream_options", None)
|
||||
self.requested_model = self.body.get("model", "default_model")
|
||||
self.adapter = self.body.get("adapters", None)
|
||||
self.max_tokens = self.body.get("max_tokens", 100)
|
||||
self.temperature = self.body.get("temperature", 1.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
@ -204,7 +210,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
# Load the model if needed
|
||||
try:
|
||||
self.model, self.tokenizer = self.model_provider.load(self.requested_model)
|
||||
self.model, self.tokenizer = self.model_provider.load(
|
||||
self.requested_model, self.adapter
|
||||
)
|
||||
except:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
@ -278,6 +286,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
if not isinstance(self.requested_model, str):
|
||||
raise ValueError("model must be a string")
|
||||
if self.adapter is not None and not isinstance(self.adapter, str):
|
||||
raise ValueError("adapter must be a string")
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
|
@ -12,7 +12,7 @@ class DummyModelProvider:
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||
|
||||
def load(self, model):
|
||||
def load(self, model, adapter=None):
|
||||
assert model in ["default_model", "chat_model"]
|
||||
return self.model, self.tokenizer
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user