Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -163,6 +163,23 @@ TEST_CASE("test flatten") {
CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
}
TEST_CASE("test unflatten") {
array x = array(1);
CHECK_THROWS(unflatten(x, 0, {1, 1}));
x = array({1});
auto out = unflatten(x, 0, {1, 1});
CHECK_EQ(out.shape(), Shape({1, 1}));
CHECK_THROWS(unflatten(x, 1, {1, 1}));
CHECK_THROWS(unflatten(x, 0, {-1, -1}));
CHECK_THROWS(unflatten(x, 0, {-1, 2}));
CHECK_THROWS(unflatten(x, 0, {}));
x = zeros({4, 8});
out = unflatten(x, 1, {2, 2, 2});
CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));
}
TEST_CASE("test squeeze and expand") {
array x = zeros({2, 1, 2, 1, 2, 1});
CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});