Add model_config parameter to load() and load_model() (#770)

* Add `model_config` parameter to `load()` and `load_model()`

For easy editing of the loaded model configuration (e.g., for changing RoPE theta or scaling of Phi-3 model)

Example:

```python
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed", model_config={"rope_theta":50000.0})
response = generate(model, tokenizer, prompt, max_tokens=MAX_TOKENS)
```

* Possible bug (default_loss)

* Revert "Possible bug (default_loss)"

This reverts commit 70a55ace18.

* Fix default_loss for lora

* 1. move load_model's new optional `model_config` arg to the end (fetch_from_hub()'s `model = load_model(model_path, lazy)`) 2. fix indentations (`black` hook)
This commit is contained in:
JosefAlbers 2024-05-11 02:13:34 +09:00 committed by GitHub
parent 6f0a69e682
commit 10853b57d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -299,7 +299,11 @@ def load_config(model_path: Path) -> dict:
return config
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
def load_model(
model_path: Path,
lazy: bool = False,
model_config: dict = {},
) -> nn.Module:
"""
Load and initialize the model from a given path.
@ -308,6 +312,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
model_config(dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
Returns:
nn.Module: The loaded and initialized model.
@ -318,6 +324,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
"""
config = load_config(model_path)
config.update(model_config)
weight_files = glob.glob(str(model_path / "model*.safetensors"))
@ -365,6 +372,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
def load(
path_or_hf_repo: str,
tokenizer_config={},
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
@ -375,6 +383,8 @@ def load(
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config(dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
@ -389,7 +399,7 @@ def load(
"""
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy)
model = load_model(model_path, lazy, model_config)
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model.eval()