Metadata support for safetensors (#639)

* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
This commit is contained in:
Diogo
2024-02-08 22:33:15 -05:00
committed by GitHub
parent 221f8d3fc2
commit b57bd0488d
8 changed files with 108 additions and 69 deletions

View File

@@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") {
auto map = std::unordered_map<std::string, array>();
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
map.insert({"test2", ones({2, 2})});
save_safetensors(file_path, map);
auto dict = load_safetensors(file_path);
auto _metadata = std::unordered_map<std::string, std::string>();
_metadata.insert({"test", "test"});
_metadata.insert({"test2", "test2"});
save_safetensors(file_path, map, _metadata);
auto [dict, metadata] = load_safetensors(file_path);
CHECK_EQ(metadata, _metadata);
CHECK_EQ(dict.size(), 2);
CHECK_EQ(dict.count("test"), 1);
CHECK_EQ(dict.count("test2"), 1);
@@ -55,7 +61,7 @@ TEST_CASE("test gguf") {
}
// Test saving and loading string metadata
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_str", "my string"});
save_gguf(file_path, original_weights, original_metadata);
@@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") {
// Scalar array
{
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_arr", array(1.0)});
save_gguf(file_path, original_weights, original_metadata);
@@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") {
// 1D Array
{
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
auto arr = array({1.0, 2.0});
original_metadata.insert({"test_arr", arr});
save_gguf(file_path, original_weights, original_metadata);
@@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") {
// > 1D array throws
{
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
}
// empty array throws
{
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_arr", array({})});
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
}
// vector of string
{
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
std::vector<std::string> data = {"data1", "data2", "data1234"};
original_metadata.insert({"meta", data});
save_gguf(file_path, original_weights, original_metadata);
@@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") {
// vector of string, string, scalar, and array
{
std::unordered_map<std::string, MetaData> original_metadata;
std::unordered_map<std::string, GGUFMetaData> original_metadata;
std::vector<std::string> data = {"data1", "data2", "data1234"};
original_metadata.insert({"meta1", data});
original_metadata.insert({"meta2", array(2.5)});