mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it alittle more readable * addressing comments * python binding tests
This commit is contained in:
@@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") {
|
||||
auto map = std::unordered_map<std::string, array>();
|
||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||
map.insert({"test2", ones({2, 2})});
|
||||
save_safetensors(file_path, map);
|
||||
auto dict = load_safetensors(file_path);
|
||||
auto _metadata = std::unordered_map<std::string, std::string>();
|
||||
_metadata.insert({"test", "test"});
|
||||
_metadata.insert({"test2", "test2"});
|
||||
save_safetensors(file_path, map, _metadata);
|
||||
auto [dict, metadata] = load_safetensors(file_path);
|
||||
|
||||
CHECK_EQ(metadata, _metadata);
|
||||
|
||||
CHECK_EQ(dict.size(), 2);
|
||||
CHECK_EQ(dict.count("test"), 1);
|
||||
CHECK_EQ(dict.count("test2"), 1);
|
||||
@@ -55,7 +61,7 @@ TEST_CASE("test gguf") {
|
||||
}
|
||||
|
||||
// Test saving and loading string metadata
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
original_metadata.insert({"test_str", "my string"});
|
||||
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
@@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") {
|
||||
|
||||
// Scalar array
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
original_metadata.insert({"test_arr", array(1.0)});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
|
||||
@@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") {
|
||||
|
||||
// 1D Array
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
auto arr = array({1.0, 2.0});
|
||||
original_metadata.insert({"test_arr", arr});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
@@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") {
|
||||
|
||||
// > 1D array throws
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
|
||||
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
||||
}
|
||||
|
||||
// empty array throws
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
original_metadata.insert({"test_arr", array({})});
|
||||
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
||||
}
|
||||
|
||||
// vector of string
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
||||
original_metadata.insert({"meta", data});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
@@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") {
|
||||
|
||||
// vector of string, string, scalar, and array
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
||||
original_metadata.insert({"meta1", data});
|
||||
original_metadata.insert({"meta2", array(2.5)});
|
||||
|
Reference in New Issue
Block a user