mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
chore(mlx-lm): add model weight index in save_weights (#413)
* chore(mlx-lm): add model weight index in save_weights * Update llms/mlx_lm/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * chore: save total siZe as param size isntead of file size * chore: clean up format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
a7d139f484
commit
8b77677c05
@ -389,6 +389,25 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
||||
else "model.safetensors"
|
||||
)
|
||||
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||
mx.save_safetensors(str(save_path / shard_name), shard)
|
||||
shard_path = save_path / shard_name
|
||||
|
||||
mx.save_safetensors(str(shard_path), shard)
|
||||
|
||||
for weight_name in shard.keys():
|
||||
index_data["weight_map"][weight_name] = shard_name
|
||||
|
||||
index_data["weight_map"] = {
|
||||
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||
}
|
||||
|
||||
with open(save_path / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(
|
||||
index_data,
|
||||
f,
|
||||
indent=4,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user