Add metadata when saving safetensors (#496)

* Add metadata when saving safetensors

Add metadata format="pt" for safetensors so that model's are accessible to `transformers` users as well.

* save with metadata format mlx

Save the model weights with metadata format of "mlx".

* Updated llms/mlx_lm/generate.py
This commit is contained in:
Alex Ishida 2024-02-29 00:29:00 +09:00 committed by GitHub
parent ea92f623d6
commit ab0f1dd1b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -502,7 +502,7 @@ def save_weights(
shard_name = shard_file_format.format(i + 1, shards_count) shard_name = shard_file_format.format(i + 1, shards_count)
shard_path = save_path / shard_name shard_path = save_path / shard_name
mx.save_safetensors(str(shard_path), shard) mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})
for weight_name in shard.keys(): for weight_name in shard.keys():
index_data["weight_map"][weight_name] = shard_name index_data["weight_map"][weight_name] = shard_name