diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 3f590f1c..e68d3af1 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -96,7 +96,7 @@ def convert(
 ):
     print("[INFO] Loading")
     model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path)
+    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

     weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
@@ -110,7 +110,8 @@ def convert(
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)

-    save_weights(mlx_path, weights)
+    del model
+    save_weights(mlx_path, weights, donate_weights=True)

     py_files = glob.glob(str(model_path / "*.py"))
     for file in py_files:
diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py
index 999d081e..2603653d 100644
--- a/llms/mlx_lm/merge.py
+++ b/llms/mlx_lm/merge.py
@@ -118,10 +118,10 @@ def merge(
     # Load all models
     base_hf_path = model_paths[0]
     base_path = get_model_path(base_hf_path)
-    base_model, base_config, tokenizer = fetch_from_hub(base_path)
+    base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
     models = []
     for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp))
+        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
         base_type = base_config["model_type"]
         model_type = config["model_type"]
         if base_type != model_type:
@@ -138,7 +138,8 @@ def merge(
     # Save base model
     mlx_path = Path(mlx_path)
     weights = dict(tree_flatten(base_model.parameters()))
-    save_weights(mlx_path, weights)
+    del models, base_model
+    save_weights(mlx_path, weights, donate_weights=True)
     py_files = glob.glob(str(base_path / "*.py"))
     for file in py_files:
         shutil.copy(file, mlx_path)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 3afc8b85..e10a8a08 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -1,4 +1,5 @@
 import copy
+import gc
 import glob
 import importlib
 import json
@@ -254,12 +255,15 @@ def generate(
     return token_string


-def load_model(model_path: Path) -> nn.Module:
+def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.

     Args:
         model_path (Path): The path to load the model from.
+        lazy (bool): If False, eval the model parameters to make sure they are
+            loaded in memory before returning; otherwise they will be loaded
+            when needed. Default: ``False``

     Returns:
         nn.Module: The loaded and initialized model.
@@ -315,14 +319,18 @@ def load_model(model_path: Path) -> nn.Module:
     model.load_weights(list(weights.items()))

-    mx.eval(model.parameters())
+    if not lazy:
+        mx.eval(model.parameters())

     model.eval()

     return model


 def load(
-    path_or_hf_repo: str, tokenizer_config={}, adapter_file: str = None
+    path_or_hf_repo: str,
+    tokenizer_config={},
+    adapter_file: str = None,
+    lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -333,6 +341,9 @@ def load(
             Defaults to an empty dictionary.
         adapter_file (str, optional): Path to the adapter file. If provided,
             applies LoRA layers to the model. Defaults to None.
+        lazy (bool): If False, eval the model parameters to make sure they are
+            loaded in memory before returning; otherwise they will be loaded
+            when needed. Default: ``False``

     Returns:
         Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
@@ -342,7 +353,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model = load_model(model_path)
+    model = load_model(model_path, lazy)
     if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
         model.eval()
@@ -352,9 +363,9 @@ def load(


 def fetch_from_hub(
-    model_path: Path,
+    model_path: Path, lazy: bool = False
 ) -> Tuple[Dict, dict, PreTrainedTokenizer]:
-    model = load_model(model_path)
+    model = load_model(model_path, lazy)
     config = AutoConfig.from_pretrained(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path)

@@ -431,7 +442,12 @@ response = generate(model, tokenizer, prompt="hello", verbose=True)
     )


-def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+def save_weights(
+    save_path: Union[str, Path],
+    weights: Dict[str, Any],
+    *,
+    donate_weights: bool = False,
+) -> None:
     """Save model weights into specified directory."""
     if isinstance(save_path, str):
         save_path = Path(save_path)
@@ -448,7 +464,15 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
     total_size = sum(v.nbytes for v in weights.values())
     index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

-    for i, shard in enumerate(shards):
+    # Write the weights and make sure no references are kept other than the
+    # necessary ones
+    if donate_weights:
+        weights.clear()
+        gc.collect()
+
+    for i in range(len(shards)):
+        shard = shards[i]
+        shards[i] = None
         shard_name = shard_file_format.format(i + 1, shards_count)
         shard_path = save_path / shard_name
@@ -456,6 +480,8 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:

         for weight_name in shard.keys():
             index_data["weight_map"][weight_name] = shard_name
+        del shard
+        gc.collect()

     index_data["weight_map"] = {
         k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
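Note for reviewers: a minimal sketch of the new lazy-loading path from the caller's side. The repo id is a placeholder, and the import path assumes the `mlx_lm` package layout this diff touches:

```python
# Sketch: load a model lazily, then force evaluation only when needed.
import mlx.core as mx

from mlx_lm.utils import fetch_from_hub, get_model_path

model_path = get_model_path("mistralai/Mistral-7B-v0.1")  # placeholder repo

# With lazy=True, load_model skips mx.eval on the parameters, so they stay
# unevaluated (backed by the on-disk safetensors) until first use.
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

# Calling mx.eval explicitly recovers the old eager behavior on demand.
mx.eval(model.parameters())
```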
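And a sketch of the donation contract on the saving side: with `donate_weights=True`, `save_weights` consumes the dict it is given, which is why `convert.py` and `merge.py` now `del` the module before calling it. The output directory and repo id are placeholders:

```python
from mlx.utils import tree_flatten

from mlx_lm.utils import fetch_from_hub, get_model_path, save_weights

model_path = get_model_path("mistralai/Mistral-7B-v0.1")  # placeholder repo
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

# Flatten the parameters and drop the module so the dict holds the only
# remaining references to the weights.
weights = dict(tree_flatten(model.parameters()))
del model

# donate_weights=True clears `weights` up front and releases each shard as
# soon as it has been written, keeping peak memory near one shard's size.
# `weights` must not be used again after this call.
save_weights("mlx_model", weights, donate_weights=True)
```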