mirror of https://github.com/ml-explore/mlx-examples.git
Add `model_config` parameter to `load()` and `load_model()` (#770)
* Add `model_config` parameter to `load()` and `load_model()`

This allows easy editing of the loaded model configuration (e.g., changing the RoPE theta or RoPE scaling of a Phi-3 model). Example:

```python
from mlx_lm import load, generate

# Placeholder prompt and token budget for illustration.
prompt = "Why is the sky blue?"
MAX_TOKENS = 256

model, tokenizer = load(
    "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
    model_config={"rope_theta": 50000.0},
)
response = generate(model, tokenizer, prompt, max_tokens=MAX_TOKENS)
```
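The message also mentions overriding RoPE scaling. A hypothetical sketch of what that could look like, assuming the target model's `config.json` accepts a Hugging Face-style `rope_scaling` entry (the exact schema, and whether a given mlx-lm model honors it, depends on the architecture):

```python
from mlx_lm import load

# Hypothetical override: linear RoPE scaling with factor 2.0.
# The {"type": ..., "factor": ...} shape follows the Hugging Face
# convention; it is illustrative, not guaranteed for every model.
model, tokenizer = load(
    "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
    model_config={"rope_scaling": {"type": "linear", "factor": 2.0}},
)
```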
* Possible bug (default_loss)
* Revert "Possible bug (default_loss)"
This reverts commit 70a55ace18
.
* Fix default_loss for lora
* 1. Move `load_model`'s new optional `model_config` arg to the end, so that `fetch_from_hub()`'s positional call `model = load_model(model_path, lazy)` keeps working (see the sketch below). 2. Fix indentation (`black` hook).
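A minimal sketch of the ordering concern above, using a simplified stand-in signature (not the real one): appending the new parameter with a default, after the existing ones, leaves positional callers such as `fetch_from_hub()` unaffected.

```python
# Simplified stand-in for the new signature; the real load_model() does far more.
def load_model_new(model_path: str, lazy: bool = False, model_config: dict = {}):
    return model_path, lazy, model_config

# A fetch_from_hub()-style positional call still binds `lazy` correctly,
# because `model_config` was appended after the pre-existing parameters.
assert load_model_new("some/path", True) == ("some/path", True, {})
```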
parent 6f0a69e682
commit 10853b57d9
```diff
@@ -299,7 +299,11 @@ def load_config(model_path: Path) -> dict:
     return config
 
 
-def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
+def load_model(
+    model_path: Path,
+    lazy: bool = False,
+    model_config: dict = {},
+) -> nn.Module:
     """
     Load and initialize the model from a given path.
 
@@ -308,6 +312,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        model_config(dict, optional): Configuration parameters for the model.
+            Defaults to an empty dictionary.
 
     Returns:
         nn.Module: The loaded and initialized model.
@@ -318,6 +324,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
 
     config = load_config(model_path)
+    config.update(model_config)
 
     weight_files = glob.glob(str(model_path / "model*.safetensors"))
 
@@ -365,6 +372,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
+    model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
@@ -375,6 +383,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
+        model_config(dict, optional): Configuration parameters specifically for the model.
+            Defaults to an empty dictionary.
         adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
             to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
@@ -389,7 +399,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model = load_model(model_path, lazy)
+    model = load_model(model_path, lazy, model_config)
     if adapter_path is not None:
         model = apply_lora_layers(model, adapter_path)
         model.eval()
```
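The functional core of the diff is the single `config.update(model_config)` line: caller-supplied keys shallowly override the values read from the model's `config.json`. A minimal sketch of that merge, with made-up config values for illustration:

```python
# Values as they might be read from config.json (illustrative only).
config = {"rope_theta": 10000.0, "vocab_size": 32064}

# Caller-supplied overrides, as passed via load(..., model_config=...).
model_config = {"rope_theta": 50000.0}

# The shallow merge performed inside load_model(): matching keys are
# replaced, all other keys are left untouched.
config.update(model_config)
assert config == {"rope_theta": 50000.0, "vocab_size": 32064}
```

Note that the merge is shallow: a nested dict such as a `rope_scaling` entry is replaced wholesale, not merged key by key.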