From 5b4155d4d056768e979dd54a37636649d9ada5d9 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 14:52:59 -0500 Subject: [PATCH] more tests --- tests/load_tests.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 35ebe0533..c4fc554b2 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -140,15 +140,18 @@ TEST_CASE("test parseJson") { TEST_CASE("test load_safetensor") { auto safeDict = load_safetensor("../../temp.safe"); - CHECK_EQ(safeDict.size(), 1); + CHECK_EQ(safeDict.size(), 2); CHECK_EQ(safeDict.count("test"), 1); + CHECK_EQ(safeDict.count("test2"), 1); array test = safeDict.at("test"); CHECK_EQ(test.dtype(), float32); CHECK_EQ(test.shape(), std::vector({4})); - array b = array({1.0, 2.0, 3.0, 4.0}); - MESSAGE("test: " << test); - MESSAGE("b: " << b); - CHECK(array_equal(test, b).item()); + CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item()); + array test2 = safeDict.at("test2"); + MESSAGE("test2: " << test2); + CHECK_EQ(test2.dtype(), float32); + CHECK_EQ(test2.shape(), std::vector({2, 2})); + CHECK(array_equal(test2, ones({2, 2})).item()); } TEST_CASE("test single array serialization") {