mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Lazy loading models for faster convert and merge (#462)
This commit is contained in:

committed by
GitHub

parent
8eee4399f4
commit
dc4f2e0a6b
@@ -118,10 +118,10 @@ def merge(
|
||||
# Load all models
|
||||
base_hf_path = model_paths[0]
|
||||
base_path = get_model_path(base_hf_path)
|
||||
base_model, base_config, tokenizer = fetch_from_hub(base_path)
|
||||
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||
models = []
|
||||
for mp in model_paths[1:]:
|
||||
model, config, _ = fetch_from_hub(get_model_path(mp))
|
||||
model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||
base_type = base_config["model_type"]
|
||||
model_type = config["model_type"]
|
||||
if base_type != model_type:
|
||||
@@ -138,7 +138,8 @@ def merge(
|
||||
# Save base model
|
||||
mlx_path = Path(mlx_path)
|
||||
weights = dict(tree_flatten(base_model.parameters()))
|
||||
save_weights(mlx_path, weights)
|
||||
del models, base_model
|
||||
save_weights(mlx_path, weights, donate_weights=True)
|
||||
py_files = glob.glob(str(base_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, mlx_path)
|
||||
|
Reference in New Issue
Block a user