From ab0f1dd1b6cf5737494dd5f91862b6503671e9d0 Mon Sep 17 00:00:00 2001 From: Alex Ishida Date: Thu, 29 Feb 2024 00:29:00 +0900 Subject: [PATCH] Add metadata when saving safetensors (#496) * Add metadata when saving safetensors Add metadata format="pt" for safetensors so that model's are accessible to `transformers` users as well. * save with metadata format mlx Save the model weights with metadata format of "mlx". * Updated llms/mlx_lm/generate.py --- llms/mlx_lm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index c3e0b191..ba22764c 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -502,7 +502,7 @@ def save_weights( shard_name = shard_file_format.format(i + 1, shards_count) shard_path = save_path / shard_name - mx.save_safetensors(str(shard_path), shard) + mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"}) for weight_name in shard.keys(): index_data["weight_map"][weight_name] = shard_name