diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 49f94eb1d..cac6999bd 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -161,15 +161,13 @@ TEST_CASE("test jsonDeserialize") { } TEST_CASE("test save_safetensor") { + std::string file_path = get_temp_file("test_arr.safetensors"); auto map = std::unordered_map(); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test2", ones({2, 2})}); MESSAGE("SAVING"); - save_safetensor("../../temp1", map); -} - -TEST_CASE("test load_safetensor") { - auto safeDict = load_safetensor("../../temp1.safetensors"); + save_safetensor(file_path, map); + auto safeDict = load_safetensor(file_path); CHECK_EQ(safeDict.size(), 2); CHECK_EQ(safeDict.count("test"), 1); CHECK_EQ(safeDict.count("test2"), 1);