From 8b77677c05382568343d540a3e9f24155be1f06c Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 7 Feb 2024 00:32:15 +1100
Subject: [PATCH] chore(mlx-lm): add model weight index in save_weights (#413)

* chore(mlx-lm): add model weight index in save_weights

* Update llms/mlx_lm/utils.py

Co-authored-by: Awni Hannun

* Update llms/mlx_lm/utils.py

Co-authored-by: Awni Hannun

* chore: save total size as param size instead of file size

* chore: clean up format

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/utils.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 61b8d9c0..32c9b7b4 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -389,6 +389,25 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
         else "model.safetensors"
     )
 
+    total_size = sum(v.nbytes for v in weights.values())
+    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
-        mx.save_safetensors(str(save_path / shard_name), shard)
+        shard_path = save_path / shard_name
+
+        mx.save_safetensors(str(shard_path), shard)
+
+        for weight_name in shard.keys():
+            index_data["weight_map"][weight_name] = shard_name
+
+    index_data["weight_map"] = {
+        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+    }
+
+    with open(save_path / "model.safetensors.index.json", "w") as f:
+        json.dump(
+            index_data,
+            f,
+            indent=4,
+        )
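
Note: the JSON file written above follows the sharded-checkpoint index layout
used by Hugging Face transformers (a "metadata.total_size" field plus a
"weight_map" from weight name to shard file). Below is a minimal sketch of
consuming that index after the patch is applied. The output directory name
and the example weight key are hypothetical, not part of the patch:

    import json
    from pathlib import Path

    model_dir = Path("mlx_model")  # hypothetical save_weights output dir

    with open(model_dir / "model.safetensors.index.json") as f:
        index_data = json.load(f)

    # metadata.total_size is the summed parameter size in bytes
    # (v.nbytes over all weights, per the patch), not file size on disk.
    print(index_data["metadata"]["total_size"])

    # weight_map maps each weight name to the shard file containing it.
    shard_name = index_data["weight_map"]["model.embed_tokens.weight"]  # hypothetical key
    print(f"weight lives in {shard_name}")

Because the weight_map is sorted by key before dumping, the index is stable
across saves of the same model, which keeps diffs of the generated file small.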