Lazy loading models for faster convert and merge (#462)

Angelos Katharopoulos 2024-02-20 13:36:55 -08:00 committed by GitHub
parent 8eee4399f4
commit dc4f2e0a6b
3 changed files with 41 additions and 13 deletions
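In MLX, array computation is deferred: loading and transforming weights only builds a graph until something forces evaluation. This commit threads a `lazy` flag through `load_model`, `load`, and `fetch_from_hub` so that `convert` and `merge` can skip the up-front `mx.eval(model.parameters())`, and adds a keyword-only `donate_weights` option to `save_weights` that drops references to the weights as shards are written. A rough usage sketch of the new flow follows; the `mlx_lm.utils` import path and the repo id are assumptions for illustration, not part of this diff:

from pathlib import Path

from mlx.utils import tree_flatten

# Assumed import path; these helpers appear in the hunks below.
from mlx_lm.utils import get_model_path, load_model, save_weights

model_path = get_model_path("mlx-community/some-model")  # hypothetical repo id
model = load_model(model_path, lazy=True)         # parameters stay unevaluated

weights = dict(tree_flatten(model.parameters()))  # still lazy arrays
del model                                          # drop the extra reference

# With donate_weights=True, save_weights clears `weights` and releases each
# shard after writing it, so only one shard is materialized at a time.
save_weights(Path("mlx_model"), weights, donate_weights=True)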


@@ -96,7 +96,7 @@ def convert(
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path)
+    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

     weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
@@ -110,7 +110,8 @@ def convert(
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)

-    save_weights(mlx_path, weights)
+    del model
+    save_weights(mlx_path, weights, donate_weights=True)

     py_files = glob.glob(str(model_path / "*.py"))
     for file in py_files:


@@ -118,10 +118,10 @@ def merge(
     # Load all models
     base_hf_path = model_paths[0]
     base_path = get_model_path(base_hf_path)
-    base_model, base_config, tokenizer = fetch_from_hub(base_path)
+    base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
     models = []
     for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp))
+        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
         base_type = base_config["model_type"]
         model_type = config["model_type"]
         if base_type != model_type:
@@ -138,7 +138,8 @@ def merge(
     # Save base model
     mlx_path = Path(mlx_path)
     weights = dict(tree_flatten(base_model.parameters()))
-    save_weights(mlx_path, weights)
+    del models, base_model
+    save_weights(mlx_path, weights, donate_weights=True)
     py_files = glob.glob(str(base_path / "*.py"))
     for file in py_files:
         shutil.copy(file, mlx_path)


@@ -1,4 +1,5 @@
 import copy
+import gc
 import glob
 import importlib
 import json
@@ -254,12 +255,15 @@ def generate(
     return token_string


-def load_model(model_path: Path) -> nn.Module:
+def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.

     Args:
         model_path (Path): The path to load the model from.
+        lazy (bool): If False eval the model parameters to make sure they are
+            loaded in memory before returning, otherwise they will be loaded
+            when needed. Default: ``False``

     Returns:
         nn.Module: The loaded and initialized model.
@@ -315,6 +319,7 @@ def load_model(model_path: Path) -> nn.Module:
     model.load_weights(list(weights.items()))

-    mx.eval(model.parameters())
+    if not lazy:
+        mx.eval(model.parameters())

     model.eval()
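The `if not lazy:` guard above is the core of the change: with the default `lazy=False` the parameters are still forced into memory at load time, while `lazy=True` leaves them as unevaluated arrays. A standalone illustration of MLX's deferred evaluation (not part of this diff):

import mlx.core as mx

a = mx.ones((4096, 4096))
b = (a * 2).astype(mx.float16)  # builds a graph node; nothing is computed yet
mx.eval(b)                      # evaluation happens here and memory is allocated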
@@ -322,7 +327,10 @@ def load_model(model_path: Path) -> nn.Module:
 def load(
-    path_or_hf_repo: str, tokenizer_config={}, adapter_file: str = None
+    path_or_hf_repo: str,
+    tokenizer_config={},
+    adapter_file: str = None,
+    lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -333,6 +341,9 @@ def load(
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
             Defaults to None.
+        lazy (bool): If False eval the model parameters to make sure they are
+            loaded in memory before returning, otherwise they will be loaded
+            when needed. Default: ``False``

     Returns:
         Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
@@ -342,7 +353,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model = load_model(model_path)
+    model = load_model(model_path, lazy)
     if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
         model.eval()
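For callers of the public `load` API the flag is optional and defaults to the old eager behavior. A hedged usage sketch (the package-level import and the repo id are assumptions):

from mlx_lm import load, generate

# lazy=True defers weight evaluation; generation still works because the
# weights are materialized on first use.
model, tokenizer = load("mlx-community/some-model", lazy=True)
response = generate(model, tokenizer, prompt="hello", verbose=True)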
@@ -352,9 +363,9 @@ def load(
 def fetch_from_hub(
-    model_path: Path,
+    model_path: Path, lazy: bool = False
 ) -> Tuple[Dict, dict, PreTrainedTokenizer]:
-    model = load_model(model_path)
+    model = load_model(model_path, lazy)
     config = AutoConfig.from_pretrained(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -431,7 +442,12 @@ response = generate(model, tokenizer, prompt="hello", verbose=True)
 )

-def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+def save_weights(
+    save_path: Union[str, Path],
+    weights: Dict[str, Any],
+    *,
+    donate_weights: bool = False,
+) -> None:
     """Save model weights into specified directory."""
     if isinstance(save_path, str):
         save_path = Path(save_path)
@@ -448,7 +464,15 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
     total_size = sum(v.nbytes for v in weights.values())
     index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

-    for i, shard in enumerate(shards):
+    # Write the weights and make sure no references are kept other than the
+    # necessary ones
+    if donate_weights:
+        weights.clear()
+        gc.collect()
+
+    for i in range(len(shards)):
+        shard = shards[i]
+        shards[i] = None
         shard_name = shard_file_format.format(i + 1, shards_count)
         shard_path = save_path / shard_name
@@ -456,6 +480,8 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
         for weight_name in shard.keys():
             index_data["weight_map"][weight_name] = shard_name
+        del shard
+        gc.collect()

     index_data["weight_map"] = {
         k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
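The `donate_weights` path is plain reference management: clear the caller's dict, drop each shard from the list before writing it, and delete it afterwards so the allocator can reclaim the buffers before the next shard is materialized. A generic sketch of the same pattern (the helper names are made up for illustration):

import gc

def write_shards(shards, write_one):
    # Drop every reference to a shard as soon as it has been written so its
    # buffers can be freed before the next shard is materialized.
    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None   # the list no longer keeps the shard alive
        write_one(shard)
        del shard          # last reference gone
        gc.collect()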