Save lora config (#636)

* lora config

* comments

* version bump
Author: Awni Hannun
Date: 2024-04-02 13:52:53 -07:00
Committed by: GitHub
parent d661440dbb
commit 2bd64b78cf
10 changed files with 73 additions and 90 deletions


@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: Optional[str] = None,
+    adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
@@ -364,8 +364,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
-        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
-            Defaults to None.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -379,8 +379,8 @@ def load(
     model_path = get_model_path(path_or_hf_repo)

     model = load_model(model_path, lazy)
-    if adapter_file is not None:
-        model = apply_lora_layers(model, adapter_file)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
         model.eval()

     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
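
For context, a minimal usage sketch of the renamed argument (not part of this commit), assuming the top-level mlx_lm load/generate helpers; the model repo and the ./adapters directory below are placeholders:

from mlx_lm import load, generate

# Assumption: ./adapters is a directory produced by LoRA fine-tuning that holds
# the adapter weights together with the saved LoRA config.
model, tokenizer = load(
    "mistralai/Mistral-7B-Instruct-v0.2",  # any supported HF repo or local path
    adapter_path="./adapters",             # replaces the old adapter_file argument
)

print(generate(model, tokenizer, prompt="Write a haiku about arrays.", max_tokens=100))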