mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
implemented Flatten Module (#149)
* implemented flatten op --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -58,6 +58,27 @@ TEST_CASE("test reshape") {
|
||||
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
|
||||
}
|
||||
|
||||
TEST_CASE("test flatten") {
|
||||
array x = zeros({2, 3, 4});
|
||||
CHECK_EQ(flatten(x).shape(), std::vector<int>({2 * 3 * 4}));
|
||||
|
||||
CHECK_EQ(flatten(x, 1, 1).shape(), std::vector<int>({2, 3, 4}));
|
||||
CHECK_EQ(flatten(x, 1, 2).shape(), std::vector<int>({2, 3 * 4}));
|
||||
CHECK_EQ(flatten(x, 1, 3).shape(), std::vector<int>({2, 3 * 4}));
|
||||
CHECK_EQ(flatten(x, 1, -1).shape(), std::vector<int>({2, 3 * 4}));
|
||||
CHECK_EQ(flatten(x, -2, -1).shape(), std::vector<int>({2, 3 * 4}));
|
||||
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({2 * 3 * 4}));
|
||||
CHECK_EQ(flatten(x, -4, -1).shape(), std::vector<int>({2 * 3 * 4}));
|
||||
|
||||
// Check start > end throws
|
||||
CHECK_THROWS(flatten(x, 2, 1));
|
||||
|
||||
// Check scalar flattens to 1D
|
||||
x = array(1);
|
||||
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
|
||||
CHECK_EQ(flatten(x, 0, 0).shape(), std::vector<int>({1}));
|
||||
}
|
||||
|
||||
TEST_CASE("test squeeze and expand") {
|
||||
array x = zeros({2, 1, 2, 1, 2, 1});
|
||||
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});
|
||||
|
Reference in New Issue
Block a user