mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -163,6 +163,23 @@ TEST_CASE("test flatten") {
|
||||
CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
|
||||
}
|
||||
|
||||
TEST_CASE("test unflatten") {
|
||||
array x = array(1);
|
||||
CHECK_THROWS(unflatten(x, 0, {1, 1}));
|
||||
|
||||
x = array({1});
|
||||
auto out = unflatten(x, 0, {1, 1});
|
||||
CHECK_EQ(out.shape(), Shape({1, 1}));
|
||||
CHECK_THROWS(unflatten(x, 1, {1, 1}));
|
||||
CHECK_THROWS(unflatten(x, 0, {-1, -1}));
|
||||
CHECK_THROWS(unflatten(x, 0, {-1, 2}));
|
||||
CHECK_THROWS(unflatten(x, 0, {}));
|
||||
|
||||
x = zeros({4, 8});
|
||||
out = unflatten(x, 1, {2, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));
|
||||
}
|
||||
|
||||
TEST_CASE("test squeeze and expand") {
|
||||
array x = zeros({2, 1, 2, 1, 2, 1});
|
||||
CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});
|
||||
|
||||
Reference in New Issue
Block a user