mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
chore(mlx-lm): add model weight index in save_weights (#413)
* chore(mlx-lm): add model weight index in save_weights * Update llms/mlx_lm/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * chore: save total siZe as param size isntead of file size * chore: clean up format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
a7d139f484
commit
8b77677c05
@ -389,6 +389,25 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
|||||||
else "model.safetensors"
|
else "model.safetensors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
total_size = sum(v.nbytes for v in weights.values())
|
||||||
|
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||||
|
|
||||||
for i, shard in enumerate(shards):
|
for i, shard in enumerate(shards):
|
||||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||||
mx.save_safetensors(str(save_path / shard_name), shard)
|
shard_path = save_path / shard_name
|
||||||
|
|
||||||
|
mx.save_safetensors(str(shard_path), shard)
|
||||||
|
|
||||||
|
for weight_name in shard.keys():
|
||||||
|
index_data["weight_map"][weight_name] = shard_name
|
||||||
|
|
||||||
|
index_data["weight_map"] = {
|
||||||
|
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(save_path / "model.safetensors.index.json", "w") as f:
|
||||||
|
json.dump(
|
||||||
|
index_data,
|
||||||
|
f,
|
||||||
|
indent=4,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user