mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-23 14:08:07 +08:00 
			
		
		
		
	Add model_config parameter to load() and load_model() (#770)
				
					
				
			* Add `model_config` parameter to `load()` and `load_model()`
For easy editing of the loaded model configuration (e.g., for changing RoPE theta or scaling of Phi-3 model)
Example:
```python
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed", model_config={"rope_theta":50000.0})
response = generate(model, tokenizer, prompt, max_tokens=MAX_TOKENS)
```
* Possible bug (default_loss)
* Revert "Possible bug (default_loss)"
This reverts commit 70a55ace18.
* Fix default_loss for lora
* 1. move load_model's new optional `model_config` arg to the end (fetch_from_hub()'s `model = load_model(model_path, lazy)`) 2. fix indentations (`black` hook)
			
			
This commit is contained in:
		| @@ -299,7 +299,11 @@ def load_config(model_path: Path) -> dict: | |||||||
|     return config |     return config | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_model(model_path: Path, lazy: bool = False) -> nn.Module: | def load_model( | ||||||
|  |     model_path: Path, | ||||||
|  |     lazy: bool = False, | ||||||
|  |     model_config: dict = {}, | ||||||
|  | ) -> nn.Module: | ||||||
|     """ |     """ | ||||||
|     Load and initialize the model from a given path. |     Load and initialize the model from a given path. | ||||||
|  |  | ||||||
| @@ -308,6 +312,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: | |||||||
|         lazy (bool): If False eval the model parameters to make sure they are |         lazy (bool): If False eval the model parameters to make sure they are | ||||||
|             loaded in memory before returning, otherwise they will be loaded |             loaded in memory before returning, otherwise they will be loaded | ||||||
|             when needed. Default: ``False`` |             when needed. Default: ``False`` | ||||||
|  |         model_config(dict, optional): Configuration parameters for the model. | ||||||
|  |             Defaults to an empty dictionary. | ||||||
|  |  | ||||||
|     Returns: |     Returns: | ||||||
|         nn.Module: The loaded and initialized model. |         nn.Module: The loaded and initialized model. | ||||||
| @@ -318,6 +324,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     config = load_config(model_path) |     config = load_config(model_path) | ||||||
|  |     config.update(model_config) | ||||||
|  |  | ||||||
|     weight_files = glob.glob(str(model_path / "model*.safetensors")) |     weight_files = glob.glob(str(model_path / "model*.safetensors")) | ||||||
|  |  | ||||||
| @@ -365,6 +372,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: | |||||||
| def load( | def load( | ||||||
|     path_or_hf_repo: str, |     path_or_hf_repo: str, | ||||||
|     tokenizer_config={}, |     tokenizer_config={}, | ||||||
|  |     model_config={}, | ||||||
|     adapter_path: Optional[str] = None, |     adapter_path: Optional[str] = None, | ||||||
|     lazy: bool = False, |     lazy: bool = False, | ||||||
| ) -> Tuple[nn.Module, TokenizerWrapper]: | ) -> Tuple[nn.Module, TokenizerWrapper]: | ||||||
| @@ -375,6 +383,8 @@ def load( | |||||||
|         path_or_hf_repo (Path): The path or the huggingface repository to load the model from. |         path_or_hf_repo (Path): The path or the huggingface repository to load the model from. | ||||||
|         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. |         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. | ||||||
|             Defaults to an empty dictionary. |             Defaults to an empty dictionary. | ||||||
|  |         model_config(dict, optional): Configuration parameters specifically for the model. | ||||||
|  |             Defaults to an empty dictionary. | ||||||
|         adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers |         adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers | ||||||
|             to the model. Default: ``None``. |             to the model. Default: ``None``. | ||||||
|         lazy (bool): If False eval the model parameters to make sure they are |         lazy (bool): If False eval the model parameters to make sure they are | ||||||
| @@ -389,7 +399,7 @@ def load( | |||||||
|     """ |     """ | ||||||
|     model_path = get_model_path(path_or_hf_repo) |     model_path = get_model_path(path_or_hf_repo) | ||||||
|  |  | ||||||
|     model = load_model(model_path, lazy) |     model = load_model(model_path, lazy, model_config) | ||||||
|     if adapter_path is not None: |     if adapter_path is not None: | ||||||
|         model = apply_lora_layers(model, adapter_path) |         model = apply_lora_layers(model, adapter_path) | ||||||
|         model.eval() |         model.eval() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 JosefAlbers
					JosefAlbers