Lazy loading models for faster convert and merge (#462)

This commit is contained in:
Angelos Katharopoulos
2024-02-20 13:36:55 -08:00
committed by GitHub
parent 8eee4399f4
commit dc4f2e0a6b
3 changed files with 41 additions and 13 deletions

View File

@@ -118,10 +118,10 @@ def merge(
# Load all models
base_hf_path = model_paths[0]
base_path = get_model_path(base_hf_path)
-base_model, base_config, tokenizer = fetch_from_hub(base_path)
+base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
models = []
for mp in model_paths[1:]:
-model, config, _ = fetch_from_hub(get_model_path(mp))
+model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
base_type = base_config["model_type"]
model_type = config["model_type"]
if base_type != model_type:
@@ -138,7 +138,8 @@ def merge(
# Save base model
mlx_path = Path(mlx_path)
weights = dict(tree_flatten(base_model.parameters()))
-save_weights(mlx_path, weights)
+del models, base_model
+save_weights(mlx_path, weights, donate_weights=True)
py_files = glob.glob(str(base_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)