more tests

This commit is contained in:
dc-dc-dc 2023-12-18 14:52:59 -05:00
parent 42baa095d1
commit 5b4155d4d0

View File

@ -140,15 +140,18 @@ TEST_CASE("test parseJson") {
TEST_CASE("test load_safetensor") { TEST_CASE("test load_safetensor") {
auto safeDict = load_safetensor("../../temp.safe"); auto safeDict = load_safetensor("../../temp.safe");
CHECK_EQ(safeDict.size(), 1); CHECK_EQ(safeDict.size(), 2);
CHECK_EQ(safeDict.count("test"), 1); CHECK_EQ(safeDict.count("test"), 1);
CHECK_EQ(safeDict.count("test2"), 1);
array test = safeDict.at("test"); array test = safeDict.at("test");
CHECK_EQ(test.dtype(), float32); CHECK_EQ(test.dtype(), float32);
CHECK_EQ(test.shape(), std::vector<int>({4})); CHECK_EQ(test.shape(), std::vector<int>({4}));
array b = array({1.0, 2.0, 3.0, 4.0}); CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
MESSAGE("test: " << test); array test2 = safeDict.at("test2");
MESSAGE("b: " << b); MESSAGE("test2: " << test2);
CHECK(array_equal(test, b).item<bool>()); CHECK_EQ(test2.dtype(), float32);
CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
} }
TEST_CASE("test single array serialization") { TEST_CASE("test single array serialization") {