From 8b77677c05382568343d540a3e9f24155be1f06c Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 7 Feb 2024 00:32:15 +1100
Subject: [PATCH] chore(mlx-lm): add model weight index in save_weights (#413)
* chore(mlx-lm): add model weight index in save_weights
* Update llms/mlx_lm/utils.py
Co-authored-by: Awni Hannun
* Update llms/mlx_lm/utils.py
Co-authored-by: Awni Hannun
* chore: save total size as param size instead of file size
* chore: clean up format
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/utils.py | 21 ++++++++++++++++++++-
1 file changed, 20 insertions(+), 1 deletion(-)
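
For reference, the index produced by this change follows the Hugging Face
sharded-checkpoint convention. A minimal sketch of the resulting
model.safetensors.index.json, shown as the Python dict the code serializes
(the shard names and byte count are illustrative, not taken from a real run):

    # Illustrative contents; "total_size" is the sum of v.nbytes over all
    # weights (parameter bytes, not on-disk file size), and "weight_map"
    # maps each weight name to the shard file that contains it.
    index_data = {
        "metadata": {"total_size": 14483464192},  # hypothetical byte count
        "weight_map": {
            "lm_head.weight": "model-00002-of-00002.safetensors",
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        },
    }
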
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 61b8d9c0..32c9b7b4 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -389,6 +389,25 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
else "model.safetensors"
)
+ total_size = sum(v.nbytes for v in weights.values())
+ index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+
for i, shard in enumerate(shards):
shard_name = shard_file_format.format(i + 1, shards_count)
- mx.save_safetensors(str(save_path / shard_name), shard)
+ shard_path = save_path / shard_name
+
+ mx.save_safetensors(str(shard_path), shard)
+
+ for weight_name in shard.keys():
+ index_data["weight_map"][weight_name] = shard_name
+
+ index_data["weight_map"] = {
+ k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+ }
+
+ with open(save_path / "model.safetensors.index.json", "w") as f:
+ json.dump(
+ index_data,
+ f,
+ indent=4,
+ )
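
A minimal usage sketch, assuming mlx and mlx-lm are installed and that
save_weights is imported from mlx_lm.utils as defined in this file; the toy
weight names and shapes are illustrative, not part of the patch:

    import json
    from pathlib import Path

    import mlx.core as mx
    from mlx_lm.utils import save_weights

    # Ensure the output directory exists (harmless if save_weights
    # also creates it).
    Path("./tiny_model").mkdir(exist_ok=True)

    # Two toy parameters; a real checkpoint would hold full model weights.
    weights = {
        "model.embed_tokens.weight": mx.zeros((8, 4)),
        "lm_head.weight": mx.zeros((8, 4)),
    }
    save_weights("./tiny_model", weights)

    # After this patch, a Hugging Face-style index is written next to
    # the safetensors shards.
    with open(Path("./tiny_model") / "model.safetensors.index.json") as f:
        index = json.load(f)

    print(index["metadata"]["total_size"])  # sum of v.nbytes over all weights
    print(index["weight_map"])              # weight name -> shard file name

Note that at this commit the index is written unconditionally, so even a
single-shard save maps every weight to model.safetensors.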