implemented Flatten Module (#149)

* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
__mo_san__
2023-12-17 06:54:37 +01:00
committed by GitHub
parent eebd7c275d
commit 52e1589a52
8 changed files with 113 additions and 2 deletions

View File

@@ -58,6 +58,27 @@ TEST_CASE("test reshape") {
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
}
TEST_CASE("test flatten") {
array x = zeros({2, 3, 4});
CHECK_EQ(flatten(x).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, 1, 1).shape(), std::vector<int>({2, 3, 4}));
CHECK_EQ(flatten(x, 1, 2).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, 3).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -2, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, -4, -1).shape(), std::vector<int>({2 * 3 * 4}));
// Check start > end throws
CHECK_THROWS(flatten(x, 2, 1));
// Check scalar flattens to 1D
x = array(1);
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
CHECK_EQ(flatten(x, 0, 0).shape(), std::vector<int>({1}));
}
TEST_CASE("test squeeze and expand") {
array x = zeros({2, 1, 2, 1, 2, 1});
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});