updated tests

This commit is contained in:
dc-dc-dc 2023-12-18 17:04:32 -05:00
parent f09bcc7d50
commit 18a1c335d0

View File

@ -161,15 +161,13 @@ TEST_CASE("test jsonDeserialize") {
} }
TEST_CASE("test save_safetensor") { TEST_CASE("test save_safetensor") {
std::string file_path = get_temp_file("test_arr.safetensors");
auto map = std::unordered_map<std::string, array>(); auto map = std::unordered_map<std::string, array>();
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
map.insert({"test2", ones({2, 2})}); map.insert({"test2", ones({2, 2})});
MESSAGE("SAVING"); MESSAGE("SAVING");
save_safetensor("../../temp1", map); save_safetensor(file_path, map);
} auto safeDict = load_safetensor(file_path);
TEST_CASE("test load_safetensor") {
auto safeDict = load_safetensor("../../temp1.safetensors");
CHECK_EQ(safeDict.size(), 2); CHECK_EQ(safeDict.size(), 2);
CHECK_EQ(safeDict.count("test"), 1); CHECK_EQ(safeDict.count("test"), 1);
CHECK_EQ(safeDict.count("test2"), 1); CHECK_EQ(safeDict.count("test2"), 1);