Add model_config parameter to load() and load_model()

For easy editing of the loaded model configuration (e.g., for changing RoPE theta or scaling of Phi-3 model)

Example:

```python
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed", model_config={"rope_theta":50000.0})
response = generate(model, tokenizer, prompt, max_tokens=MAX_TOKENS)
```
This commit is contained in:
JosefAlbers 2024-05-10 16:24:16 +09:00 committed by GitHub
parent fad9598372
commit a9192f81b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -299,12 +299,14 @@ def load_config(model_path: Path) -> dict:
return config
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
def load_model(model_path: Path, model_config: dict = {}, lazy: bool = False) -> nn.Module:
"""
Load and initialize the model from a given path.
Args:
model_path (Path): The path to load the model from.
model_config(dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
@ -318,6 +320,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
"""
config = load_config(model_path)
config.update(model_config)
weight_files = glob.glob(str(model_path / "model*.safetensors"))
@ -365,6 +368,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
def load(
path_or_hf_repo: str,
tokenizer_config={},
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
@ -375,6 +379,8 @@ def load(
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config(dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
@ -389,7 +395,7 @@ def load(
"""
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy)
model = load_model(model_path, model_config, lazy)
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model.eval()