mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00

Lazy loading models for faster convert and merge (#462)

parent 8eee4399f4
commit dc4f2e0a6b

@@ -96,7 +96,7 @@ def convert(
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path)
+    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

     weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
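
This change leans on MLX's lazy evaluation: with `lazy=True`, `fetch_from_hub` builds the parameter graph without materializing any weights, so `convert` can cast or quantize before anything is resident. A minimal standalone sketch of that semantics (not code from this commit):

    import mlx.core as mx

    a = mx.zeros((4096, 4096))      # lazy: records the op, allocates nothing yet
    b = (a + 1).astype(mx.float16)  # still only a computation graph
    mx.eval(b)                      # now the graph runs and b is materialized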

@@ -110,7 +110,8 @@ def convert(
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)

-    save_weights(mlx_path, weights)
+    del model
+    save_weights(mlx_path, weights, donate_weights=True)

     py_files = glob.glob(str(model_path / "*.py"))
     for file in py_files:
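
Taken together, the two convert hunks keep at most one materialized copy of the weights alive: the model is deleted once its parameters are captured in the flat dict, and `donate_weights=True` lets `save_weights` release each shard as it is written. A rough paraphrase of the resulting flow (names as in the diff, not verbatim code):

    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)  # graph only
    weights = dict(tree_flatten(model.parameters()))  # flat dict of lazy arrays
    del model                                  # the dict now owns the weights
    save_weights(mlx_path, weights, donate_weights=True)  # evaluate, write, free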

@@ -118,10 +118,10 @@ def merge(
     # Load all models
     base_hf_path = model_paths[0]
     base_path = get_model_path(base_hf_path)
-    base_model, base_config, tokenizer = fetch_from_hub(base_path)
+    base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
     models = []
     for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp))
+        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
         base_type = base_config["model_type"]
         model_type = config["model_type"]
         if base_type != model_type:
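
Loading every checkpoint with `lazy=True` means a merge over N models builds one lazy graph instead of holding N materialized weight sets. A toy, self-contained sketch of that behavior (hypothetical names, not the merge code itself):

    import mlx.core as mx

    # Stand-ins for two lazily loaded weight dicts.
    base_weights = {"w": mx.zeros((2, 2))}
    other_weights = {"w": mx.ones((2, 2))}

    # Building the average only records computation; nothing is evaluated yet.
    merged = {k: 0.5 * base_weights[k] + 0.5 * other_weights[k] for k in base_weights}

    mx.eval(list(merged.values()))  # inputs and result materialize only here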

@@ -138,7 +138,8 @@ def merge(
     # Save base model
     mlx_path = Path(mlx_path)
     weights = dict(tree_flatten(base_model.parameters()))
-    save_weights(mlx_path, weights)
+    del models, base_model
+    save_weights(mlx_path, weights, donate_weights=True)
     py_files = glob.glob(str(base_path / "*.py"))
     for file in py_files:
         shutil.copy(file, mlx_path)

@@ -1,4 +1,5 @@
 import copy
+import gc
 import glob
 import importlib
 import json

@@ -254,12 +255,15 @@ def generate(
     return token_string


-def load_model(model_path: Path) -> nn.Module:
+def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.

     Args:
         model_path (Path): The path to load the model from.
+        lazy (bool): If False eval the model parameters to make sure they are
+            loaded in memory before returning, otherwise they will be loaded
+            when needed. Default: ``False``

     Returns:
         nn.Module: The loaded and initialized model.

@@ -315,6 +319,7 @@ def load_model(model_path: Path) -> nn.Module:
     model.load_weights(list(weights.items()))

-    mx.eval(model.parameters())
+    if not lazy:
+        mx.eval(model.parameters())

     model.eval()
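
The `if not lazy` gate is the whole mechanism: `mx.eval` accepts a nested structure and forces every array in it. A minimal sketch with a toy module (assumes only `mlx.nn`):

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Linear(8, 8)
    params = layer.parameters()  # nested dict of lazy arrays
    mx.eval(params)              # force the whole tree into memory, as
                                 # load_model does when lazy=False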

@@ -322,7 +327,10 @@


 def load(
-    path_or_hf_repo: str, tokenizer_config={}, adapter_file: str = None
+    path_or_hf_repo: str,
+    tokenizer_config={},
+    adapter_file: str = None,
+    lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.

@@ -333,6 +341,9 @@ def load(
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
             Defaults to None.
+        lazy (bool): If False eval the model parameters to make sure they are
+            loaded in memory before returning, otherwise they will be loaded
+            when needed. Default: ``False``
     Returns:
         Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.


@@ -342,7 +353,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model = load_model(model_path)
+    model = load_model(model_path, lazy)
     if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
         model.eval()
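
Callers opt in through the new keyword. A usage sketch, assuming the package-level `load`/`generate` re-exports and a placeholder repo id:

    from mlx_lm import load, generate

    # lazy=True returns immediately; weights are pulled into memory as the
    # computation first touches them.
    model, tokenizer = load("mlx-community/some-model-4bit", lazy=True)  # placeholder id
    text = generate(model, tokenizer, prompt="hello")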

@@ -352,9 +363,9 @@ def load(


 def fetch_from_hub(
-    model_path: Path,
+    model_path: Path, lazy: bool = False
 ) -> Tuple[Dict, dict, PreTrainedTokenizer]:
-    model = load_model(model_path)
+    model = load_model(model_path, lazy)

     config = AutoConfig.from_pretrained(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path)

@@ -431,7 +442,12 @@ response = generate(model, tokenizer, prompt="hello", verbose=True)
     )


-def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+def save_weights(
+    save_path: Union[str, Path],
+    weights: Dict[str, Any],
+    *,
+    donate_weights: bool = False,
+) -> None:
     """Save model weights into specified directory."""
     if isinstance(save_path, str):
         save_path = Path(save_path)
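
The bare `*` in the new signature makes `donate_weights` keyword-only, so a destructive option can never be passed positionally by accident:

    save_weights(mlx_path, weights, donate_weights=True)  # OK, intent is explicit
    save_weights(mlx_path, weights, True)                 # TypeError: too many positional arguments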

@@ -448,7 +464,15 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
     total_size = sum(v.nbytes for v in weights.values())
     index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

-    for i, shard in enumerate(shards):
+    # Write the weights and make sure no references are kept other than the
+    # necessary ones
+    if donate_weights:
+        weights.clear()
+        gc.collect()
+
+    for i in range(len(shards)):
+        shard = shards[i]
+        shards[i] = None
         shard_name = shard_file_format.format(i + 1, shards_count)
         shard_path = save_path / shard_name

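
The loop rewrite is a general reference-hygiene pattern: make each shard the sole owner of its arrays while it is written, then drop it so peak memory stays near one shard rather than the whole model. A self-contained toy version in plain Python:

    import gc

    # Toy "shards": a few megabyte-sized buffers standing in for weight shards.
    shards = [{f"w{i}": bytearray(1 << 20)} for i in range(4)]

    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None      # the list no longer keeps the shard alive
        # ... write `shard` to disk here ...
        del shard             # last reference gone; the buffer is freeable
        gc.collect()          # prompt collection so memory is returned promptly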

@@ -456,6 +480,8 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:

         for weight_name in shard.keys():
             index_data["weight_map"][weight_name] = shard_name
+        del shard
+        gc.collect()

     index_data["weight_map"] = {
         k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])