Add model_config parameter to load() and load_model()

Allows easy editing of the loaded model configuration (e.g., changing the RoPE theta or rope scaling of a Phi-3 model).

Example:

```python
from mlx_lm import load, generate

MAX_TOKENS = 256
prompt = "Why is the sky blue?"
model, tokenizer = load("mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed", model_config={"rope_theta": 50000.0})
response = generate(model, tokenizer, prompt, max_tokens=MAX_TOKENS)
```
JosefAlbers 2024-05-10 16:24:16 +09:00 committed by GitHub
parent fad9598372
commit a9192f81b1

```diff
@@ -299,12 +299,14 @@ def load_config(model_path: Path) -> dict:
     return config
 
 
-def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
+def load_model(model_path: Path, model_config: dict = {}, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.
 
     Args:
         model_path (Path): The path to load the model from.
+        model_config(dict, optional): Configuration parameters for the model.
+            Defaults to an empty dictionary.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -318,6 +320,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     config = load_config(model_path)
+    config.update(model_config)
 
     weight_files = glob.glob(str(model_path / "model*.safetensors"))
@@ -365,6 +368,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
+    model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
@@ -375,6 +379,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
+        model_config(dict, optional): Configuration parameters specifically for the model.
+            Defaults to an empty dictionary.
         adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
             to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
@@ -389,7 +395,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model = load_model(model_path, lazy)
+    model = load_model(model_path, model_config, lazy)
     if adapter_path is not None:
         model = apply_lora_layers(model, adapter_path)
         model.eval()
```