mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:38:07 +08:00
more tests
This commit is contained in:
parent
42baa095d1
commit
5b4155d4d0
@ -140,15 +140,18 @@ TEST_CASE("test parseJson") {
|
|||||||
|
|
||||||
TEST_CASE("test load_safetensor") {
|
TEST_CASE("test load_safetensor") {
|
||||||
auto safeDict = load_safetensor("../../temp.safe");
|
auto safeDict = load_safetensor("../../temp.safe");
|
||||||
CHECK_EQ(safeDict.size(), 1);
|
CHECK_EQ(safeDict.size(), 2);
|
||||||
CHECK_EQ(safeDict.count("test"), 1);
|
CHECK_EQ(safeDict.count("test"), 1);
|
||||||
|
CHECK_EQ(safeDict.count("test2"), 1);
|
||||||
array test = safeDict.at("test");
|
array test = safeDict.at("test");
|
||||||
CHECK_EQ(test.dtype(), float32);
|
CHECK_EQ(test.dtype(), float32);
|
||||||
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
||||||
array b = array({1.0, 2.0, 3.0, 4.0});
|
CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
|
||||||
MESSAGE("test: " << test);
|
array test2 = safeDict.at("test2");
|
||||||
MESSAGE("b: " << b);
|
MESSAGE("test2: " << test2);
|
||||||
CHECK(array_equal(test, b).item<bool>());
|
CHECK_EQ(test2.dtype(), float32);
|
||||||
|
CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
|
||||||
|
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test single array serialization") {
|
TEST_CASE("test single array serialization") {
|
||||||
|
Loading…
Reference in New Issue
Block a user