diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 61b8d9c0..32c9b7b4 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -389,6 +389,25 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None: else "model.safetensors" ) + total_size = sum(v.nbytes for v in weights.values()) + index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} + for i, shard in enumerate(shards): shard_name = shard_file_format.format(i + 1, shards_count) - mx.save_safetensors(str(save_path / shard_name), shard) + shard_path = save_path / shard_name + + mx.save_safetensors(str(shard_path), shard) + + for weight_name in shard.keys(): + index_data["weight_map"][weight_name] = shard_name + + index_data["weight_map"] = { + k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) + } + + with open(save_path / "model.safetensors.index.json", "w") as f: + json.dump( + index_data, + f, + indent=4, + )