More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -395,7 +395,7 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{8, 4});
CHECK_EQ(out[2].shape(), Shape{8, 4});
out = split(x, std::vector<int>{});
out = split(x, Shape{});
CHECK_EQ(out.size(), 1);
CHECK_EQ(out[0].shape(), x.shape());
@@ -405,25 +405,25 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{4, 12});
CHECK_EQ(out[2].shape(), Shape{1, 12});
out = split(x, std::vector<int>{20});
out = split(x, Shape{20});
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), Shape{8, 12});
CHECK_EQ(out[1].shape(), Shape{0, 12});
// Negative indices
out = split(x, std::vector<int>{-5});
out = split(x, Shape{-5});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{5, 12});
// Different axis
out = split(x, std::vector<int>{2, 8}, 1);
out = split(x, {2, 8}, 1);
CHECK_EQ(out[0].shape(), Shape{8, 2});
CHECK_EQ(out[1].shape(), Shape{8, 6});
CHECK_EQ(out[2].shape(), Shape{8, 4});
// Out of order indices
x = arange(5);
out = split(x, std::vector<int>{2, 1, 2});
out = split(x, {2, 1, 2});
CHECK(array_equal(out[0], array({0, 1})).item<bool>());
CHECK(array_equal(out[1], array({})).item<bool>());
CHECK(array_equal(out[2], array({1})).item<bool>());